From 74e413f2c41750954c7a59ecf8cb3e73977654a5 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 1 Feb 2023 12:57:47 +0300 Subject: [PATCH 01/37] add ScatterElements converter to ONNX front-end --- python/tvm/relay/frontend/onnx.py | 59 ++++++++++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index aebc6daa5ebe..95590fe5e2ce 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2853,6 +2853,63 @@ def _impl_v1(cls, inputs, attr, params): return _op.scatter(inputs[0], inputs[1], inputs[2], axis) +class ScatterElements(OnnxOpConverter): + """Operator converter for ScatterElements.""" + + @classmethod + def _args_check(cls, inputs, attr, red_valids=None): + ret = [] + assert ( + len(inputs) == 3 + ), "ScatterElements takes 3 inputs (data, indices, updates), {} given".format(len(inputs)) + assert infer_type(inputs[1]).checked_type.dtype in ["int32", "int64"] + + axis = attr.get("axis", 0) + rank = len(infer_shape(inputs[0])) + assert rank > 0, "Data rank higher than 0 is expected" + assert -rank <= axis < rank, "Axis is out of bounds" + ret.append(axis) + + if red_valids: + reduction = attr.get("reduction", None) + assert reduction in red_valids + ret.append(reduction) + + return ret + + @classmethod + def _impl_v11(cls, inputs, attr, params): + axis = cls._args_check(inputs, attr) + + return _op.scatter(inputs[0], inputs[1], inputs[2], axis) + + @classmethod + def _impl_v16(cls, inputs, attr, params): + axis, reduction = cls._args_check(inputs, attr, [None, "add", "mul"]) + + if not reduction: + return _op.scatter(inputs[0], inputs[1], inputs[2], axis) + elif reduction == "add": + return _op.scatter_add(inputs[0], inputs[1], inputs[2], axis) + # else: # There is no other choices due to it was checked earlier + # return _op.scatter_mul(inputs[0], inputs[1], inputs[2], axis) + + @classmethod + def _impl_v18(cls, inputs, attr, params): + axis, reduction = cls._args_check(inputs, attr, [None, "add", "mul", "min", "max"]) + + if not reduction: + return _op.scatter(inputs[0], inputs[1], inputs[2], axis) + elif reduction == "add": + return _op.scatter_add(inputs[0], inputs[1], inputs[2], axis) + # elif reduction == "mul": + # return _op.scatter_mul(inputs[0], inputs[1], inputs[2], axis) + # elif reduction == "min": + # return _op.scatter_min(inputs[0], inputs[1], inputs[2], axis) + # else: # There is no other choices due to it was checked earlier + # return _op.scatter_max(inputs[0], inputs[1], inputs[2], axis) + + class ScatterND(OnnxOpConverter): """Operator converter for ScatterND.""" @@ -6588,7 +6645,7 @@ def _get_convert_map(opset): "Compress": Compress.get_converter(opset), "Size": AttrCvt("ndarray_size", extras={"dtype": "int64"}), "Scatter": Scatter.get_converter(opset), - "ScatterElements": Scatter.get_converter(opset), + "ScatterElements": ScatterElements.get_converter(opset), "ScatterND": ScatterND.get_converter(opset), "EyeLike": EyeLike.get_converter(opset), "Squeeze": Squeeze.get_converter(opset), From eafbe11830ee90e985f0196b5000fd4b436d4ed3 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 2 Feb 2023 14:42:23 +0300 Subject: [PATCH 02/37] native front-end for ScatterElements was implemented --- include/tvm/relay/attrs/transform.h | 12 +++++++ python/tvm/relay/op/transform.py | 28 +++++++++++++++ src/relay/op/tensor/transform.cc | 56 +++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 82403d7c40ee..fe8008ed6069 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -164,6 +164,18 @@ struct ScatterAddAttrs : public tvm::AttrsNode { } }; +struct ScatterElementsAttrs : public tvm::AttrsNode { + Integer axis; + String reduction; + + TVM_DECLARE_ATTRS(ScatterElementsAttrs, "relay.attrs.ScatterElementsAttrs") { + TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values."); + TVM_ATTR_FIELD(reduction).set_default("update").describe( + "Reduction mode of the scatter elements, " + "either \"update\" or \"add\" or \"mul\" or \"min\" or \"max\"."); + } +}; + struct ScatterNDAttrs : public tvm::AttrsNode { String mode; diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 782797dadb83..0621c40c0081 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -403,6 +403,34 @@ def scatter_add(data, indices, updates, axis): return _make.scatter_add(data, indices, updates, axis) +def scatter_elements(data, indices, updates, axis, reduction): + """Update data by adding values in updates at positions defined by indices. + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + indices : relay.Expr + The index locations to update. + + updates : relay.Expr + The values to add. + + axis : int + The axis to scatter_add on. + + reduction : string, optional + The reduction mode for scatter. Choise is from ["update", "add", "mul", "min", max"] + + Returns + ------- + ret : relay.Expr + The computed result. + """ + return _make.scatter_elements(data, indices, updates, axis, reduction) + + def scatter_nd(data, indices, updates, mode="update"): """Scatter values from an array and update. diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 5c5cd6f4b721..13a1acf15a95 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1191,6 +1191,62 @@ RELAY_REGISTER_OP("scatter_add") .set_attr("TOpPattern", kOpaque) .set_support_level(10); +// scatter_elements operator +TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs); + +bool ScatterElementsRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, indices, updates, output] + ICHECK_EQ(types.size(), 4); + const auto* data = types[0].as(); + const auto* indices = types[1].as(); + const auto* updates = types[2].as(); + if (data == nullptr) { + ICHECK(types[0].as()) + << "ScatterElements: expect input data type to be TensorType but got " << types[0]; + return false; + } + if (indices == nullptr) { + ICHECK(types[1].as()) + << "ScatterElements: expect indices type to be TensorType but got " << types[1]; + return false; + } + if (updates == nullptr) { + ICHECK(types[2].as()) + << "ScatterElements: expect updates type to be TensorType but got " << types[2]; + return false; + } + // TODO(vvchernov): ONNX requires int32 and int64 + ICHECK(indices->dtype.is_int() || indices->dtype.is_uint()) + << "ScatterElements: indices must be a tensor of integers."; + + // Assign output + reporter->Assign(types[3], TensorType(data->shape, data->dtype)); + return true; +} + +Expr MakeScatterElements(Expr data, Expr indices, Expr updates, int axis, String reduction) { + auto attrs = make_object(); + attrs->axis = std::move(axis); + attrs->reduction = std::move(reduction); + static const Op& op = Op::Get("scatter_elements"); + return Call(op, {data, indices, updates}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.scatter_elements").set_body_typed(MakeScatterElements); + +// scatter_elements op has extern schedules: convert to Opaque to prevent compilation failures +RELAY_REGISTER_OP("scatter_elements") + .describe(R"code(Scatter elements with updating data by reduction values in updates +at positions defined by indices.)code" TVM_ADD_FILELINE) + .set_num_inputs(3) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("indices", "Tensor", "The indices tensor.") + .add_argument("updates", "Tensor", "The input tensor of updates.") + .set_support_level(3) + .add_type_rel("ScatterElements", ScatterElementsRel) + .set_attr("TOpPattern", kOpaque); + // scatter_nd operator TVM_REGISTER_NODE_TYPE(ScatterNDAttrs); From 12cf5322b20103a7472aaff4eeaed1aa6db18610 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 2 Feb 2023 15:00:08 +0300 Subject: [PATCH 03/37] update ScatterElements in ONNX high-level front-end --- python/tvm/relay/frontend/onnx.py | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 95590fe5e2ce..a5a0b14d5434 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2872,6 +2872,8 @@ def _args_check(cls, inputs, attr, red_valids=None): if red_valids: reduction = attr.get("reduction", None) + if reduction is None: + reduction = "update" assert reduction in red_valids ret.append(reduction) @@ -2885,29 +2887,15 @@ def _impl_v11(cls, inputs, attr, params): @classmethod def _impl_v16(cls, inputs, attr, params): - axis, reduction = cls._args_check(inputs, attr, [None, "add", "mul"]) + axis, reduction = cls._args_check(inputs, attr, ["update", "add", "mul"]) - if not reduction: - return _op.scatter(inputs[0], inputs[1], inputs[2], axis) - elif reduction == "add": - return _op.scatter_add(inputs[0], inputs[1], inputs[2], axis) - # else: # There is no other choices due to it was checked earlier - # return _op.scatter_mul(inputs[0], inputs[1], inputs[2], axis) + return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis, reduction) @classmethod def _impl_v18(cls, inputs, attr, params): - axis, reduction = cls._args_check(inputs, attr, [None, "add", "mul", "min", "max"]) - - if not reduction: - return _op.scatter(inputs[0], inputs[1], inputs[2], axis) - elif reduction == "add": - return _op.scatter_add(inputs[0], inputs[1], inputs[2], axis) - # elif reduction == "mul": - # return _op.scatter_mul(inputs[0], inputs[1], inputs[2], axis) - # elif reduction == "min": - # return _op.scatter_min(inputs[0], inputs[1], inputs[2], axis) - # else: # There is no other choices due to it was checked earlier - # return _op.scatter_max(inputs[0], inputs[1], inputs[2], axis) + axis, reduction = cls._args_check(inputs, attr, ["update", "add", "mul", "min", "max"]) + + return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis, reduction) class ScatterND(OnnxOpConverter): From 64656a72a98b161d3db8291281217fb5e4c8d784 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 3 Feb 2023 09:24:46 +0300 Subject: [PATCH 04/37] update comments --- python/tvm/relay/op/transform.py | 8 +++++--- src/relay/op/tensor/transform.cc | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 0621c40c0081..1f97d21945bd 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -403,8 +403,9 @@ def scatter_add(data, indices, updates, axis): return _make.scatter_add(data, indices, updates, axis) -def scatter_elements(data, indices, updates, axis, reduction): - """Update data by adding values in updates at positions defined by indices. +def scatter_elements(data, indices, updates, axis, reduction="update"): + """Scatter elements with updating data by reduction of values in updates + at positions defined by indices. Parameters ---------- @@ -421,7 +422,8 @@ def scatter_elements(data, indices, updates, axis, reduction): The axis to scatter_add on. reduction : string, optional - The reduction mode for scatter. Choise is from ["update", "add", "mul", "min", max"] + The reduction mode for scatter. Choise is from ["update", "add", "mul", "min", max"]. + It is "update" by default. Returns ------- diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 13a1acf15a95..01e5a7f5f359 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1237,7 +1237,7 @@ TVM_REGISTER_GLOBAL("relay.op._make.scatter_elements").set_body_typed(MakeScatte // scatter_elements op has extern schedules: convert to Opaque to prevent compilation failures RELAY_REGISTER_OP("scatter_elements") - .describe(R"code(Scatter elements with updating data by reduction values in updates + .describe(R"code(Scatter elements with updating data by reduction of values in updates at positions defined by indices.)code" TVM_ADD_FILELINE) .set_num_inputs(3) .add_argument("data", "Tensor", "The input tensor.") From 90a4b305a2f36f09ba0dfb7836dfe07091f0e137 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 3 Feb 2023 09:49:54 +0300 Subject: [PATCH 05/37] register ScatterElementsAttrs --- python/tvm/relay/op/op_attrs.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index b76097722c07..ea7c415b511f 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -639,6 +639,11 @@ class ScatterAddAttrs(Attrs): """Attributes used in scatter_add operators""" +@tvm._ffi.register_object("relay.attrs.ScatterElementsAttrs") +class ScatterElementsAttrs(Attrs): + """Attributes used in scatter_elements operators""" + + @tvm._ffi.register_object("relay.attrs.ScatterNDAttrs") class ScatterNDAttrs(Attrs): """Attributes used in scatter_nd operators""" From 709b74e5c449d4e8027e237f452f0e0e90f36ec8 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 3 Feb 2023 09:59:15 +0300 Subject: [PATCH 06/37] register scatter elements strategy --- python/tvm/relay/op/_transform.py | 10 ++++++++++ python/tvm/relay/op/strategy/generic.py | 22 ++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index d4e4a527835a..e40179ed2d03 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -204,6 +204,15 @@ def compute_scatter_add(attrs, inputs, output_type): _reg.register_strategy("scatter_add", strategy.scatter_add_strategy) +# scatter_elements +@_reg.register_compute("scatter_elements") +def compute_scatter_elements(attrs, inputs, output_type): + """Compute definition of scatter_elements""" + return [topi.scatter_elements(inputs[0], inputs[1], inputs[2], attrs.axis, attrs.reduction)] + + +_reg.register_strategy("scatter_elements", strategy.scatter_elements_strategy) + # scatter_nd @_reg.register_compute("scatter_nd") def compute_scatter_nd(attrs, inputs, output_type): @@ -679,6 +688,7 @@ def argwhere_shape_func(attrs, inputs, out_ndims): _reg.register_shape_func("scatter", False, elemwise_shape_func) _reg.register_shape_func("scatter_add", False, elemwise_shape_func) +_reg.register_shape_func("scatter_elements", False, elemwise_shape_func) _reg.register_shape_func("scatter_nd", False, elemwise_shape_func) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 1cf55f7145cd..4e0448f1799b 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1580,6 +1580,28 @@ def scatter_add_strategy(attrs, outs, out_type, target): return strategy +# scatter_elements +@override_native_generic_func("scatter_elements_strategy") +def scatter_elements_strategy(attrs, inputs, out_type, target): + """scatter_elements generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_scatter_elements(topi.scatter_elements), + wrap_topi_schedule(topi.generic.schedule_extern), + name="scatter_elements.generic", + ) + return strategy + + +def wrap_compute_scatter_elements(topi_compute): + """Wrap scatter_elements topi compute""" + + def _compute_scatter_elements(attrs, inputs, _): + return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.axis, attrs.reduction)] + + return _compute_scatter_elements + + # scatter_nd @override_native_generic_func("scatter_nd_strategy") def scatter_nd_strategy(attrs, inputs, out_type, target): From 1327f7df24ce4ebd4b3de3d038254fa3e0071946 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 3 Feb 2023 15:31:12 +0300 Subject: [PATCH 07/37] implement generic scatter elements in topi --- python/tvm/topi/__init__.py | 1 + python/tvm/topi/scatter_elements.py | 126 ++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+) create mode 100644 python/tvm/topi/scatter_elements.py diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index 785ba395d25b..75867136e09e 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -38,6 +38,7 @@ from .broadcast import * from .sort import * from .scatter import * +from .scatter_elements import * from .sparse_fill_empty_rows import * from .sparse_reshape import * from .scatter_add import * diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py new file mode 100644 index 000000000000..cb49d375228c --- /dev/null +++ b/python/tvm/topi/scatter_elements.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin, invalid-name, too-many-arguments, too-many-nested-blocks +"""ScatterElements operator""" +from ..te import extern +from ..tir import min, max, decl_buffer, ir_builder + + +def scatter_elements(data, indices, updates, axis=0, reduction="update"): + """Scatter elements from updates to corresponding indices of copied data. + + Data, indices, updates and output have the same shape. + Indices can not have duplicates (if idx1 != idx2, then indices[idx1] != indices[idx2]) + if reduction == "update". + + .. code-block:: + + output[indices[i][j]][j] = f(output[indices[i][j]][j], updates[i][j]) if axis = 0, + output[i][indices[i][j]] = f(output[i][indices[i][j]], updates[i][j]) if axis = 1, + + where the update function f is determinted by the reduction. + Five types of the function are supported: replace, +, *, min or max + + Parameters + ---------- + data : tvm.te.Tensor + The source array. + + indices : tvm.te.Tensor + The indices of the values to extract. + + updates : tvm.te.Tensor + The updates to apply at the Indices + + axis : optional, int + The axis to scatter on. It is zero by default. + + reduction : optional, string + The update mode for the algorithm, either "update", "add", "mul", "min" or "max" + If update, the update values will replace the input data + If add, the update values will be added to the input data + If mul, the update values will be multiply to the input data + If min, there is choice of minimal between the update values and the input data + If max, there is choice of maximal between the update values and the input data + It is "update" by default + + Returns + ------- + ret : tvm.te.Tensor + """ + + def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): + # pylint: disable=invalid-name + ib = ir_builder.create() + + data = ib.buffer_ptr(data_ptr) + indices = ib.buffer_ptr(indices_ptr) + updates = ib.buffer_ptr(updates_ptr) + out = ib.buffer_ptr(out_ptr) + + # Prepare ranges and strides + before_axis_range = 1 + for i in data_ptr.shape[:axis]: + before_axis_range *= i + + axis_range = data_ptr.shape[axis] + + after_axis_range = 1 + for i in data_ptr.shape[axis + 1 :]: + after_axis_range *= i + before_axis_stride = axis_range * after_axis_range + + # Copy initial input data to output + fused_shape = before_axis_range * before_axis_stride + + with ib.for_range(0, fused_shape) as i: + out[i] = data[i] + + with ib.for_range(0, before_axis_range, kind="parallel") as i: + with ib.for_range(0, after_axis_range, kind="parallel") as j: + with ib.for_range(0, axis_range, kind="parallel") as k: + pre_index = i * before_axis_stride + j + index1 = pre_index + k * after_axis_range + index2 = pre_index + indices[index1] + if reduction == "update": + out[index2] = updates[index1] + elif reduction == "add": + out[index2] += updates[index1] + elif reduction == "mul": + out[index2] *= updates[index1] + elif reduction == "min": + min(out[index2], updates[index1]) + elif reduction == "max": + max(out[index2], updates[index1]) + else: + raise NotImplementedError( + "scatter_elements reduction not in [update, add, mul, min, max]:", + reduction, + ) + + return ib.get() + + out_buf = decl_buffer(data.shape, data.dtype, "out_buf") + return extern( + [data.shape], + [data, indices, updates], + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), + dtype=data.dtype, + out_buffers=[out_buf], + name="scatter_elements.generic", + tag="scatter_elements.generic", + ) From f16648c35b4adf604aca5bc992eaec60c8ce73d8 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Sat, 4 Feb 2023 11:00:01 +0300 Subject: [PATCH 08/37] fix min-max redefinition --- python/tvm/topi/scatter_elements.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index cb49d375228c..ccac226013ab 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -14,10 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=redefined-builtin, invalid-name, too-many-arguments, too-many-nested-blocks """ScatterElements operator""" -from ..te import extern -from ..tir import min, max, decl_buffer, ir_builder +from tvm import te +from tvm import tir def scatter_elements(data, indices, updates, axis=0, reduction="update"): @@ -65,7 +64,7 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): # pylint: disable=invalid-name - ib = ir_builder.create() + ib = tir.ir_builder.create() data = ib.buffer_ptr(data_ptr) indices = ib.buffer_ptr(indices_ptr) @@ -103,9 +102,9 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): elif reduction == "mul": out[index2] *= updates[index1] elif reduction == "min": - min(out[index2], updates[index1]) + tir.min(out[index2], updates[index1]) elif reduction == "max": - max(out[index2], updates[index1]) + tir.max(out[index2], updates[index1]) else: raise NotImplementedError( "scatter_elements reduction not in [update, add, mul, min, max]:", @@ -114,8 +113,8 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): return ib.get() - out_buf = decl_buffer(data.shape, data.dtype, "out_buf") - return extern( + out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf") + return te.extern( [data.shape], [data, indices, updates], lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), From 39f4fc0cae025ba1c00e056cc120353ed1bb92c6 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Sat, 4 Feb 2023 15:41:14 +0300 Subject: [PATCH 09/37] fix IntImm conversion and update scatter element implementation --- include/tvm/relay/attrs/transform.h | 2 +- python/tvm/topi/scatter_elements.py | 44 ++++++++++++++++------------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index fe8008ed6069..b5333961ebf9 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -172,7 +172,7 @@ struct ScatterElementsAttrs : public tvm::AttrsNode { TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values."); TVM_ATTR_FIELD(reduction).set_default("update").describe( "Reduction mode of the scatter elements, " - "either \"update\" or \"add\" or \"mul\" or \"min\" or \"max\"."); + "either \"update\", \"add\", \"mul\", \"min\" or \"max\"."); } }; diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index ccac226013ab..40ceac383ab5 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -17,6 +17,7 @@ """ScatterElements operator""" from tvm import te from tvm import tir +from . import utils def scatter_elements(data, indices, updates, axis=0, reduction="update"): @@ -61,6 +62,25 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): ------- ret : tvm.te.Tensor """ + if not isinstance(axis, int): + axis = utils.get_const_int(axis) + + shape = data.shape + axis_range = shape[axis] + + if axis < 0: + axis = len(shape) + axis + + # Prepare ranges and strides + before_axis_range = 1 + after_axis_range = 1 + for i, value in enumerate(shape, 0): + if i < axis: + before_axis_range *= value + elif i > axis: + after_axis_range *= value + before_axis_stride = axis_range * after_axis_range + full_range = before_axis_range * before_axis_stride def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): # pylint: disable=invalid-name @@ -71,30 +91,16 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): updates = ib.buffer_ptr(updates_ptr) out = ib.buffer_ptr(out_ptr) - # Prepare ranges and strides - before_axis_range = 1 - for i in data_ptr.shape[:axis]: - before_axis_range *= i - - axis_range = data_ptr.shape[axis] - - after_axis_range = 1 - for i in data_ptr.shape[axis + 1 :]: - after_axis_range *= i - before_axis_stride = axis_range * after_axis_range - # Copy initial input data to output - fused_shape = before_axis_range * before_axis_stride - - with ib.for_range(0, fused_shape) as i: + with ib.for_range(0, full_range, "i", kind="parallel") as i: out[i] = data[i] - with ib.for_range(0, before_axis_range, kind="parallel") as i: - with ib.for_range(0, after_axis_range, kind="parallel") as j: - with ib.for_range(0, axis_range, kind="parallel") as k: + with ib.for_range(0, before_axis_range, "i", kind="parallel") as i: + with ib.for_range(0, after_axis_range, "j", kind="parallel") as j: + with ib.for_range(0, axis_range, "k", kind="parallel") as k: pre_index = i * before_axis_stride + j index1 = pre_index + k * after_axis_range - index2 = pre_index + indices[index1] + index2 = pre_index + indices[index1] * after_axis_range if reduction == "update": out[index2] = updates[index1] elif reduction == "add": From c293094b957ef5ec81273b5764ec6e5e8b56282e Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Sat, 4 Feb 2023 18:52:34 +0300 Subject: [PATCH 10/37] fix parallel approach --- python/tvm/topi/scatter_elements.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index 40ceac383ab5..de0ecd749c3d 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -95,9 +95,10 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): with ib.for_range(0, full_range, "i", kind="parallel") as i: out[i] = data[i] + # TODO(vvchernov): find optimal parallel approach with ib.for_range(0, before_axis_range, "i", kind="parallel") as i: - with ib.for_range(0, after_axis_range, "j", kind="parallel") as j: - with ib.for_range(0, axis_range, "k", kind="parallel") as k: + with ib.for_range(0, after_axis_range, "j") as j: + with ib.for_range(0, axis_range, "k") as k: pre_index = i * before_axis_stride + j index1 = pre_index + k * after_axis_range index2 = pre_index + indices[index1] * after_axis_range From e92ffb9ec37c4cb778de66ae1db2644bf4c82849 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Sun, 5 Feb 2023 15:36:03 +0300 Subject: [PATCH 11/37] CI tests for scatter elements were added --- tests/python/frontend/onnx/test_forward.py | 72 +++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 470a67e86c93..17a6715941b1 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -844,7 +844,7 @@ def verify_scatter(in_shape, indices, axis): updates = np.random.uniform(size=indices.shape).astype("float32") y = helper.make_node( - "ScatterElements", ["data", "indices", "updates"], ["output"], axis=axis + "Scatter", ["data", "indices", "updates"], ["output"], axis=axis ) graph = helper.make_graph( @@ -868,6 +868,76 @@ def verify_scatter(in_shape, indices, axis): verify_scatter((4, 3, 5, 6), [[[[2, 1, 0, 0]]]], 0) +@tvm.testing.parametrize_targets +def test_scatter_elements(target, dev): + """test_scatter_elements""" + + def verify_scatter_elements(in_shape, indices, axis=0, reduction="update"): + x = np.random.uniform(size=in_shape).astype("float32") + indices = np.array(indices, dtype="int32") + updates = np.random.uniform(size=indices.shape).astype("float32") + + scatter_elements_node = helper.make_node( + "ScatterElements", + ["data", "indices", "updates"], + ["output"], + axis=axis, + reduction=reduction, + ) + + graph = helper.make_graph( + [scatter_elements_node], + "scatter_elements_test", + inputs=[ + helper.make_tensor_value_info("data", TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("indices", TensorProto.INT32, list(indices.shape)), + helper.make_tensor_value_info("updates", TensorProto.FLOAT, list(indices.shape)), + ], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(in_shape))], + ) + model = helper.make_model(graph, producer_name="scatter_elements_test") + verify_with_ort_with_inputs(model, [x, indices, updates], target=target, dev=dev) + + # Usual scatter for 1d input + verify_scatter_elements((4,), [2, 3]) + # Usual scatter with specified positive axis + verify_scatter_elements((2, 2), [[1, 0], [0, 1]], 1) + # Usual scatter for 3d input with spicified negative indices and axis + verify_scatter_elements((3, 3, 3), [[[-1, -3]]], -1) + # Usual scatter for 4d input + verify_scatter_elements((4, 3, 5, 6), [[[[2, 1, 0, 0]]]]) + # Scatter elements with addition reduction of duplicates + verify_scatter_elements( + (3, 3, 3), + [[[0, 2, 1], [1, 1, 1], [2, 1, 0]], [[0, 2, 1], [1, 1, 1], [2, 1, 0]]], + 0, + "add", + ) + # Scatter elements with reduction and specified axis + verify_scatter_elements((3, 3, 3), [[[2, 2, 2], [1, 1, 1], [0, 0, 0]]], 2, "add") + # Scatter elements with multiplication reduction of duplicates + verify_scatter_elements( + (3, 3, 3), + [[[0, 2, 1], [1, 1, 1], [2, 1, 0]], [[0, 2, 1], [1, 1, 1], [2, 1, 0]]], + 0, + "mul", + ) + # Scatter elements with min reduction of duplicates + verify_scatter_elements( + (3, 3, 3), + [[[0, 2, 1], [1, 1, 1], [2, 1, 0]], [[0, 2, 1], [1, 1, 1], [2, 1, 0]]], + 0, + "min", + ) + # Scatter elements with max reduction of duplicates + verify_scatter_elements( + (3, 3, 3), + [[[0, 2, 1], [1, 1, 1], [2, 1, 0]], [[0, 2, 1], [1, 1, 1], [2, 1, 0]]], + 0, + "max", + ) + + @tvm.testing.parametrize_targets def test_slice(target, dev): """test_slice""" From f0fb416411e64b22752806b3f0c532119ae15204 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Sun, 5 Feb 2023 16:03:51 +0300 Subject: [PATCH 12/37] small update of description --- python/tvm/relay/op/transform.py | 15 ++++++++++----- tests/python/frontend/onnx/test_forward.py | 4 +--- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 1f97d21945bd..3df13da04426 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -403,7 +403,7 @@ def scatter_add(data, indices, updates, axis): return _make.scatter_add(data, indices, updates, axis) -def scatter_elements(data, indices, updates, axis, reduction="update"): +def scatter_elements(data, indices, updates, axis=0, reduction="update"): """Scatter elements with updating data by reduction of values in updates at positions defined by indices. @@ -416,14 +416,19 @@ def scatter_elements(data, indices, updates, axis, reduction="update"): The index locations to update. updates : relay.Expr - The values to add. + The values to update. axis : int - The axis to scatter_add on. + The axis to scatter elements on. It is zero by default. reduction : string, optional - The reduction mode for scatter. Choise is from ["update", "add", "mul", "min", max"]. - It is "update" by default. + The reduction mode for scatter. Choise is from ["update", "add", "mul", "min", max"] + If update, the update values will replace the input data + If add, the update values will be added to the input data + If mul, the update values will be multiply to the input data + If min, there is choice of minimal between the update values and the input data + If max, there is choice of maximal between the update values and the input data + It is "update" by default Returns ------- diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 17a6715941b1..8bf02b951f1c 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -843,9 +843,7 @@ def verify_scatter(in_shape, indices, axis): indices = np.array(indices, dtype="int32") updates = np.random.uniform(size=indices.shape).astype("float32") - y = helper.make_node( - "Scatter", ["data", "indices", "updates"], ["output"], axis=axis - ) + y = helper.make_node("Scatter", ["data", "indices", "updates"], ["output"], axis=axis) graph = helper.make_graph( [y], From 7f9128d2f5c5e2c6f54937267c4bd698a1a4e50b Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Sun, 5 Feb 2023 20:20:29 +0300 Subject: [PATCH 13/37] sphinx issue was fixed --- python/tvm/topi/scatter_elements.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index de0ecd749c3d..ade7064e421f 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -29,11 +29,11 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): .. code-block:: - output[indices[i][j]][j] = f(output[indices[i][j]][j], updates[i][j]) if axis = 0, - output[i][indices[i][j]] = f(output[i][indices[i][j]], updates[i][j]) if axis = 1, + output[indices[i][j]][j] = f(output[indices[i][j]][j], updates[i][j]) if axis = 0 + output[i][indices[i][j]] = f(output[i][indices[i][j]], updates[i][j]) if axis = 1 where the update function f is determinted by the reduction. - Five types of the function are supported: replace, +, *, min or max + Five types of the function are supported: replace, +, \*, min or max Parameters ---------- From 1e3a663316053f92b1488b50e392e1de644dd22c Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Sun, 5 Feb 2023 20:23:28 +0300 Subject: [PATCH 14/37] fix scatter deprecation in the CI test --- tests/python/frontend/onnx/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 8bf02b951f1c..e283dd75a113 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -856,7 +856,7 @@ def verify_scatter(in_shape, indices, axis): outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(in_shape))], ) model = helper.make_model(graph, producer_name="scatter_test") - verify_with_ort_with_inputs(model, [x, indices, updates], target=target, dev=dev) + verify_with_ort_with_inputs(model, [x, indices, updates], target=target, dev=dev, opset=16) verify_scatter((4,), [1], 0) verify_scatter((1, 4), [[0]], 0) From c2653b32723b957800cd7428273041e4d9497557 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Sun, 5 Feb 2023 20:35:53 +0300 Subject: [PATCH 15/37] fix --- python/tvm/topi/scatter_elements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index ade7064e421f..ea26fc0f3167 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -33,7 +33,7 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): output[i][indices[i][j]] = f(output[i][indices[i][j]], updates[i][j]) if axis = 1 where the update function f is determinted by the reduction. - Five types of the function are supported: replace, +, \*, min or max + Five types of the function are supported: "update", "add", "mul", "min" and "max" (see below) Parameters ---------- From a7e0aae8e9e06d02bedd8b7b4f433de750d283c5 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 6 Feb 2023 08:37:06 +0300 Subject: [PATCH 16/37] fix scatter version support --- python/tvm/relay/frontend/onnx.py | 2 +- tests/python/frontend/onnx/test_forward.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index a5a0b14d5434..f8474d3e9dad 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2848,7 +2848,7 @@ class Scatter(OnnxOpConverter): """Operator converter for Scatter.""" @classmethod - def _impl_v1(cls, inputs, attr, params): + def _impl_v9(cls, inputs, attr, params): axis = attr.get("axis", 0) return _op.scatter(inputs[0], inputs[1], inputs[2], axis) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index e283dd75a113..86410545ecb6 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -856,7 +856,9 @@ def verify_scatter(in_shape, indices, axis): outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(in_shape))], ) model = helper.make_model(graph, producer_name="scatter_test") - verify_with_ort_with_inputs(model, [x, indices, updates], target=target, dev=dev, opset=16) + # Scatter operator has been supported from version 9 and + # deprecated since version 11 of the default ONNX operator set + verify_with_ort_with_inputs(model, [x, indices, updates], target=target, dev=dev, opset=9) verify_scatter((4,), [1], 0) verify_scatter((1, 4), [[0]], 0) From d527411f966e847792ce870f932ba034819ec2df Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 6 Feb 2023 08:51:37 +0300 Subject: [PATCH 17/37] fix negative indices --- python/tvm/topi/scatter_elements.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index ea26fc0f3167..ce2c10d4472d 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -101,7 +101,11 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): with ib.for_range(0, axis_range, "k") as k: pre_index = i * before_axis_stride + j index1 = pre_index + k * after_axis_range - index2 = pre_index + indices[index1] * after_axis_range + # TODO(vvchernov): assert for out of bounds + k_new = indices[index1] + if k_new < 0: + k_new += axis_range + index2 = pre_index + k_new * after_axis_range if reduction == "update": out[index2] = updates[index1] elif reduction == "add": From 3525fcb76440df5c7f76c6bf057e3871e1fada9a Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 6 Feb 2023 09:52:08 +0300 Subject: [PATCH 18/37] add scatter elements strategy for cuda, gpu --- python/tvm/relay/op/strategy/cuda.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index fa295c93a19f..fc1691fe9ef0 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -1099,6 +1099,20 @@ def scatter_add_cuda(attrs, inputs, out_type, target): return strategy +@scatter_elements_strategy.register(["cuda", "gpu"]) +def scatter_elements_cuda(attrs, inputs, out_type, target): + """scatter elements cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_scatter_elements(topi.cuda.scatter_elements), + wrap_topi_schedule(topi.cuda.schedule_extern), + name="scatter_elements.cuda", + plevel=10, + ) + # TODO(vvchernov): There is possible specification for rank=1 as for scatter + return strategy + + @scatter_nd_strategy.register(["cuda", "gpu"]) def scatter_nd_cuda(attrs, inputs, out_type, target): """scatter_nd cuda strategy""" From 38af70c7f05dc58505d7216964f9d6545675aec4 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 6 Feb 2023 11:00:21 +0300 Subject: [PATCH 19/37] update assert comment, update check of negative indices, hide tests for 18 version --- python/tvm/relay/frontend/onnx.py | 4 ++- python/tvm/topi/scatter_elements.py | 6 ++--- tests/python/frontend/onnx/test_forward.py | 29 +++++++++++----------- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index f8474d3e9dad..cb72a0afe07f 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2874,7 +2874,9 @@ def _args_check(cls, inputs, attr, red_valids=None): reduction = attr.get("reduction", None) if reduction is None: reduction = "update" - assert reduction in red_valids + assert reduction in red_valids, "Only {} modes are supported, but {} is gotten".format( + red_valids, reduction + ) ret.append(reduction) return ret diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index ce2c10d4472d..28deebe85e97 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -101,10 +101,10 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): with ib.for_range(0, axis_range, "k") as k: pre_index = i * before_axis_stride + j index1 = pre_index + k * after_axis_range - # TODO(vvchernov): assert for out of bounds + # TODO(vvchernov): assert for out of bounds, separated check for indices k_new = indices[index1] - if k_new < 0: - k_new += axis_range + index_check = tir.LT(k_new, tir.const(0, indices.dtype)) + k_new += tir.Select(index_check, axis_range, tir.const(0, indices.dtype)) index2 = pre_index + k_new * after_axis_range if reduction == "update": out[index2] = updates[index1] diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 86410545ecb6..1b01e0a24b8d 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -922,20 +922,21 @@ def verify_scatter_elements(in_shape, indices, axis=0, reduction="update"): 0, "mul", ) - # Scatter elements with min reduction of duplicates - verify_scatter_elements( - (3, 3, 3), - [[[0, 2, 1], [1, 1, 1], [2, 1, 0]], [[0, 2, 1], [1, 1, 1], [2, 1, 0]]], - 0, - "min", - ) - # Scatter elements with max reduction of duplicates - verify_scatter_elements( - (3, 3, 3), - [[[0, 2, 1], [1, 1, 1], [2, 1, 0]], [[0, 2, 1], [1, 1, 1], [2, 1, 0]]], - 0, - "max", - ) + # TODO(vvchernov): min and mul options are supported from 18 version, but CI supports 17 only + # # Scatter elements with min reduction of duplicates + # verify_scatter_elements( + # (3, 3, 3), + # [[[0, 2, 1], [1, 1, 1], [2, 1, 0]], [[0, 2, 1], [1, 1, 1], [2, 1, 0]]], + # 0, + # "min", + # ) + # # Scatter elements with max reduction of duplicates + # verify_scatter_elements( + # (3, 3, 3), + # [[[0, 2, 1], [1, 1, 1], [2, 1, 0]], [[0, 2, 1], [1, 1, 1], [2, 1, 0]]], + # 0, + # "max", + # ) @tvm.testing.parametrize_targets From ab7cf51ad4ec0a25a26f62065abe1507dc0bb36a Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 6 Feb 2023 14:08:05 +0300 Subject: [PATCH 20/37] fixes --- python/tvm/relay/frontend/onnx.py | 3 ++- python/tvm/topi/scatter_elements.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index cb72a0afe07f..53eaacf4d1c5 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2873,7 +2873,8 @@ def _args_check(cls, inputs, attr, red_valids=None): if red_valids: reduction = attr.get("reduction", None) if reduction is None: - reduction = "update" + reduction = b"update" + reduction.decode("utf-8") assert reduction in red_valids, "Only {} modes are supported, but {} is gotten".format( red_valids, reduction ) diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index 28deebe85e97..f9c0d88e8605 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -66,7 +66,7 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): axis = utils.get_const_int(axis) shape = data.shape - axis_range = shape[axis] + axis_range = int(shape[axis]) if axis < 0: axis = len(shape) + axis From 5984eb3e967aca7020097eb4684063499fac5a62 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 7 Feb 2023 12:29:30 +0300 Subject: [PATCH 21/37] extend error log for convenient analysis --- src/tir/ir/expr.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 40606761f804..a651c8cbd11a 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -393,7 +393,8 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Sp ICHECK(false_value.defined()) << "ValueError: true_value is undefined"; ICHECK(condition.dtype().is_bool()); ICHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1); - ICHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types"; + ICHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types. " << + "False type: " << false_value.dtype() << "; True type: " << true_value.dtype(); ObjectPtr node = make_object(); node->dtype = true_value.dtype(); From 46199da4b997e476837540df1ec0787883ec4146 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 7 Feb 2023 13:37:38 +0300 Subject: [PATCH 22/37] lint fix --- src/tir/ir/expr.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index a651c8cbd11a..0bd681da2ea0 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -393,8 +393,9 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Sp ICHECK(false_value.defined()) << "ValueError: true_value is undefined"; ICHECK(condition.dtype().is_bool()); ICHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1); - ICHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types. " << - "False type: " << false_value.dtype() << "; True type: " << true_value.dtype(); + ICHECK(false_value.dtype() == true_value.dtype()) + << "TypeError: mismatched types. " + << "False type: " << false_value.dtype() << "; True type: " << true_value.dtype(); ObjectPtr node = make_object(); node->dtype = true_value.dtype(); From fd2ad9ad7a245942fc73b2a58e3fd5fafd030ae5 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 7 Feb 2023 14:44:09 +0300 Subject: [PATCH 23/37] fix --- python/tvm/relay/frontend/onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 53eaacf4d1c5..1d86127249d8 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2874,7 +2874,7 @@ def _args_check(cls, inputs, attr, red_valids=None): reduction = attr.get("reduction", None) if reduction is None: reduction = b"update" - reduction.decode("utf-8") + reduction = reduction.decode("utf-8") assert reduction in red_valids, "Only {} modes are supported, but {} is gotten".format( red_valids, reduction ) From e14e4dd39d6e2cf226f4be9d73fa6b9f612c6626 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 7 Feb 2023 16:15:37 +0300 Subject: [PATCH 24/37] sync dtypes --- python/tvm/topi/scatter_elements.py | 3 ++- src/tir/ir/expr.cc | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index f9c0d88e8605..407022fe55e9 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -18,6 +18,7 @@ from tvm import te from tvm import tir from . import utils +from .math import cast def scatter_elements(data, indices, updates, axis=0, reduction="update"): @@ -66,7 +67,7 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): axis = utils.get_const_int(axis) shape = data.shape - axis_range = int(shape[axis]) + axis_range = cast(shape[axis], indices.dtype) if axis < 0: axis = len(shape) + axis diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 0bd681da2ea0..d5caeab53922 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -394,8 +394,8 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Sp ICHECK(condition.dtype().is_bool()); ICHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1); ICHECK(false_value.dtype() == true_value.dtype()) - << "TypeError: mismatched types. " - << "False type: " << false_value.dtype() << "; True type: " << true_value.dtype(); + << "TypeError: mismatched types. " + << "False type: " << false_value.dtype() << "; True type: " << true_value.dtype(); ObjectPtr node = make_object(); node->dtype = true_value.dtype(); From 77e2308872468bff71b193608ce8c1aba402143b Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 8 Feb 2023 08:55:53 +0300 Subject: [PATCH 25/37] update cpu tir for scatter elements by scan example --- python/tvm/topi/scatter_elements.py | 53 +++++++++++++++-------------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index 407022fe55e9..adaff6a75bd4 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -97,31 +97,34 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): out[i] = data[i] # TODO(vvchernov): find optimal parallel approach - with ib.for_range(0, before_axis_range, "i", kind="parallel") as i: - with ib.for_range(0, after_axis_range, "j") as j: - with ib.for_range(0, axis_range, "k") as k: - pre_index = i * before_axis_stride + j - index1 = pre_index + k * after_axis_range - # TODO(vvchernov): assert for out of bounds, separated check for indices - k_new = indices[index1] - index_check = tir.LT(k_new, tir.const(0, indices.dtype)) - k_new += tir.Select(index_check, axis_range, tir.const(0, indices.dtype)) - index2 = pre_index + k_new * after_axis_range - if reduction == "update": - out[index2] = updates[index1] - elif reduction == "add": - out[index2] += updates[index1] - elif reduction == "mul": - out[index2] *= updates[index1] - elif reduction == "min": - tir.min(out[index2], updates[index1]) - elif reduction == "max": - tir.max(out[index2], updates[index1]) - else: - raise NotImplementedError( - "scatter_elements reduction not in [update, add, mul, min, max]:", - reduction, - ) + with ib.for_range( + 0, before_axis_range * after_axis_range, "fused", kind="parallel" + ) as fused: + i = fused // after_axis_range + j = fused % after_axis_range + pre_index = i * before_axis_stride + j + with ib.for_range(0, axis_range, "k") as k: + index1 = pre_index + k * after_axis_range + # TODO(vvchernov): assert for out of bounds, separated check for indices + k_new = indices[index1] + index_check = tir.LT(k_new, tir.const(0, indices.dtype)) + k_new += tir.Select(index_check, axis_range, tir.const(0, indices.dtype)) + index2 = pre_index + k_new * after_axis_range + if reduction == "update": + out[index2] = updates[index1] + elif reduction == "add": + out[index2] += updates[index1] + elif reduction == "mul": + out[index2] *= updates[index1] + elif reduction == "min": + tir.min(out[index2], updates[index1]) + elif reduction == "max": + tir.max(out[index2], updates[index1]) + else: + raise NotImplementedError( + "scatter_elements reduction not in [update, add, mul, min, max]:", + reduction, + ) return ib.get() From 3276dd71f6159927f379916881d00301740c3ebd Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 8 Feb 2023 09:38:01 +0300 Subject: [PATCH 26/37] scatter elements was basically implemented for topi/cuda --- python/tvm/topi/cuda/__init__.py | 1 + python/tvm/topi/cuda/scatter_elements.py | 157 +++++++++++++++++++++++ 2 files changed, 158 insertions(+) create mode 100644 python/tvm/topi/cuda/scatter_elements.py diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py index 31433e3a1a54..b746c95c0fc1 100644 --- a/python/tvm/topi/cuda/__init__.py +++ b/python/tvm/topi/cuda/__init__.py @@ -46,6 +46,7 @@ from .nms import get_valid_counts, non_max_suppression, all_class_non_max_suppression from .rcnn import * from .scatter import * +from .scatter_elements import * from .sort import * from .conv2d_nhwc_tensorcore import * from .conv3d_ndhwc_tensorcore import * diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py new file mode 100644 index 000000000000..780107366f14 --- /dev/null +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -0,0 +1,157 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Scatter operator """ +import tvm +from tvm import te +from tvm import tir +from ..utils import ceil_div, get_const_int +from ..math import cast + + +def scatter_elements(data, indices, updates, axis=0, reduction="update"): + """Scatter elements from updates to corresponding indices of copied data. + + Data, indices, updates and output have the same shape. + Indices can not have duplicates (if idx1 != idx2, then indices[idx1] != indices[idx2]) + if reduction == "update". + + .. code-block:: + + output[indices[i][j]][j] = f(output[indices[i][j]][j], updates[i][j]) if axis = 0 + output[i][indices[i][j]] = f(output[i][indices[i][j]], updates[i][j]) if axis = 1 + + where the update function f is determinted by the reduction. + Five types of the function are supported: "update", "add", "mul", "min" and "max" (see below) + + Parameters + ---------- + data : tvm.te.Tensor + The source array. + + indices : tvm.te.Tensor + The indices of the values to extract. + + updates : tvm.te.Tensor + The updates to apply at the Indices + + axis : optional, int + The axis to scatter on. It is zero by default. + + reduction : optional, string + The update mode for the algorithm, either "update", "add", "mul", "min" or "max" + If update, the update values will replace the input data + If add, the update values will be added to the input data + If mul, the update values will be multiply to the input data + If min, there is choice of minimal between the update values and the input data + If max, there is choice of maximal between the update values and the input data + It is "update" by default + + Returns + ------- + ret : tvm.te.Tensor + """ + if not isinstance(axis, int): + axis = get_const_int(axis) + + shape = data.shape + axis_range = cast(shape[axis], indices.dtype) + + if axis < 0: + axis = len(shape) + axis + + # Prepare ranges and strides + before_axis_range = 1 + after_axis_range = 1 + for i, value in enumerate(shape, 0): + if i < axis: + before_axis_range *= value + elif i > axis: + after_axis_range *= value + before_axis_stride = axis_range * after_axis_range + full_range = before_axis_range * before_axis_stride + full_range_excl_axis = before_axis_range * after_axis_range + + def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): + ib = tir.ir_builder.create() + + data = ib.buffer_ptr(data_ptr) + indices = ib.buffer_ptr(indices_ptr) + updates = ib.buffer_ptr(updates_ptr) + out = ib.buffer_ptr(out_ptr) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + # Copy initial input data to output + with ib.new_scope(): + num_blocks = ceil_div(full_range, max_threads) + bx = te.thread_axis("blockIdx.x") + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(bx, "thread_extent", num_blocks) + ib.scope_attr(tx, "thread_extent", max_threads) + + index = bx * max_threads + tx + with ib.if_scope(index < full_range): + out[index] = data[index] + + # TODO (vvchernov): use atomic function for special conditions (see cuda.scatter_nd) + with ib.new_scope(): + num_blocks_2 = ceil_div(full_range_excl_axis, max_threads) + bx2 = te.thread_axis("blockIdx.x") + tx2 = te.thread_axis("threadIdx.x") + ib.scope_attr(bx2, "thread_extent", num_blocks_2) + ib.scope_attr(tx2, "thread_extent", max_threads) + + fused = bx2 * max_threads + tx2 + with ib.if_scope(fused < full_range_excl_axis): + i = fused // after_axis_range + j = fused % after_axis_range + pre_index = i * before_axis_stride + j + with ib.for_range(0, axis_range, "k") as k: + index1 = pre_index + k * after_axis_range + # TODO(vvchernov): assert for out of bounds, separated check for indices + k_new = indices[index1] + index_check = tir.LT(k_new, tir.const(0, indices.dtype)) + k_new += tir.Select(index_check, axis_range, tir.const(0, indices.dtype)) + index2 = pre_index + k_new * after_axis_range + if reduction == "update": + out[index2] = updates[index1] + elif reduction == "add": + out[index2] += updates[index1] + elif reduction == "mul": + out[index2] *= updates[index1] + elif reduction == "min": + tir.min(out[index2], updates[index1]) + elif reduction == "max": + tir.max(out[index2], updates[index1]) + else: + raise NotImplementedError( + "scatter_elements reduction not in [update, add, mul, min, max]:", + reduction, + ) + + return ib.get() + + out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf") + return te.extern( + [data.shape], + [data, indices, updates], + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), + dtype=data.dtype, + out_buffers=[out_buf], + name="scatter_elements_cuda", + tag="scatter_elements_cuda", + ) From da839a45e31badc5947fe2c798aac1054aeb3b52 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 8 Feb 2023 12:40:40 +0300 Subject: [PATCH 27/37] fix cpu scatter elements --- python/tvm/topi/scatter_elements.py | 41 ++++++++++++++++++----------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index adaff6a75bd4..600920dd82cb 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -66,22 +66,32 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): if not isinstance(axis, int): axis = utils.get_const_int(axis) - shape = data.shape - axis_range = cast(shape[axis], indices.dtype) - if axis < 0: axis = len(shape) + axis # Prepare ranges and strides - before_axis_range = 1 + shape = data.shape + axis_range = cast(shape[axis], indices.dtype) + + full_range = 1 after_axis_range = 1 for i, value in enumerate(shape, 0): - if i < axis: - before_axis_range *= value - elif i > axis: + full_range *= value + if i > axis: after_axis_range *= value before_axis_stride = axis_range * after_axis_range - full_range = before_axis_range * before_axis_stride + + ind_shape = indices.shape + ind_axis_range = ind_shape[axis] + + ind_before_axis_range = 1 + ind_after_axis_range = 1 + for i, value in enumerate(ind_shape, 0): + if i < axis: + ind_before_axis_range *= value + elif i > axis: + ind_after_axis_range *= value + ind_before_axis_stride = ind_axis_range * ind_after_axis_range def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): # pylint: disable=invalid-name @@ -98,18 +108,19 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): # TODO(vvchernov): find optimal parallel approach with ib.for_range( - 0, before_axis_range * after_axis_range, "fused", kind="parallel" + 0, ind_before_axis_range * ind_after_axis_range, "fused", kind="parallel" ) as fused: - i = fused // after_axis_range - j = fused % after_axis_range - pre_index = i * before_axis_stride + j - with ib.for_range(0, axis_range, "k") as k: - index1 = pre_index + k * after_axis_range + i = fused // ind_after_axis_range + j = fused % ind_after_axis_range + with ib.for_range(0, ind_axis_range, "k") as k: + # Offset allong indices or updates + index1 = i * ind_before_axis_stride + k * ind_after_axis_range + j # TODO(vvchernov): assert for out of bounds, separated check for indices k_new = indices[index1] index_check = tir.LT(k_new, tir.const(0, indices.dtype)) k_new += tir.Select(index_check, axis_range, tir.const(0, indices.dtype)) - index2 = pre_index + k_new * after_axis_range + # Offset allong data + index2 = i * before_axis_stride + k_new * after_axis_range + j if reduction == "update": out[index2] = updates[index1] elif reduction == "add": From b0e1f1212a87e3090d375573138edf73462d3c4f Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 8 Feb 2023 13:03:08 +0300 Subject: [PATCH 28/37] fix gpu scatter elements --- python/tvm/topi/cuda/scatter_elements.py | 39 ++++++++++++++++-------- python/tvm/topi/scatter_elements.py | 4 +-- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index 780107366f14..a6fa17c1f3fe 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -68,13 +68,13 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): if not isinstance(axis, int): axis = get_const_int(axis) - shape = data.shape - axis_range = cast(shape[axis], indices.dtype) - if axis < 0: axis = len(shape) + axis # Prepare ranges and strides + shape = data.shape + axis_range = cast(shape[axis], indices.dtype) + before_axis_range = 1 after_axis_range = 1 for i, value in enumerate(shape, 0): @@ -84,7 +84,19 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): after_axis_range *= value before_axis_stride = axis_range * after_axis_range full_range = before_axis_range * before_axis_stride - full_range_excl_axis = before_axis_range * after_axis_range + + ind_shape = indices.shape + ind_axis_range = shape[axis] + + ind_before_axis_range = 1 + ind_after_axis_range = 1 + for i, value in enumerate(ind_shape, 0): + if i < axis: + ind_before_axis_range *= value + elif i > axis: + ind_after_axis_range *= value + ind_before_axis_stride = ind_axis_range * ind_after_axis_range + ind_full_range_excl_axis = ind_before_axis_range * ind_after_axis_range def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): ib = tir.ir_builder.create() @@ -109,24 +121,25 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): # TODO (vvchernov): use atomic function for special conditions (see cuda.scatter_nd) with ib.new_scope(): - num_blocks_2 = ceil_div(full_range_excl_axis, max_threads) + num_blocks_2 = ceil_div(ind_full_range_excl_axis, max_threads) bx2 = te.thread_axis("blockIdx.x") tx2 = te.thread_axis("threadIdx.x") ib.scope_attr(bx2, "thread_extent", num_blocks_2) ib.scope_attr(tx2, "thread_extent", max_threads) - fused = bx2 * max_threads + tx2 - with ib.if_scope(fused < full_range_excl_axis): - i = fused // after_axis_range - j = fused % after_axis_range - pre_index = i * before_axis_stride + j - with ib.for_range(0, axis_range, "k") as k: - index1 = pre_index + k * after_axis_range + ind_fused = bx2 * max_threads + tx2 + with ib.if_scope(ind_fused < ind_full_range_excl_axis): + i = ind_fused // ind_after_axis_range + j = ind_fused % ind_after_axis_range + with ib.for_range(0, ind_axis_range, "k") as k: + # Offset along indices or updates + index1 = i * ind_before_axis_stride + k * ind_after_axis_range + j # TODO(vvchernov): assert for out of bounds, separated check for indices k_new = indices[index1] index_check = tir.LT(k_new, tir.const(0, indices.dtype)) k_new += tir.Select(index_check, axis_range, tir.const(0, indices.dtype)) - index2 = pre_index + k_new * after_axis_range + # Offset along data + index2 = i * before_axis_stride + k_new * after_axis_range + j if reduction == "update": out[index2] = updates[index1] elif reduction == "add": diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index 600920dd82cb..170113a136a0 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -113,13 +113,13 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): i = fused // ind_after_axis_range j = fused % ind_after_axis_range with ib.for_range(0, ind_axis_range, "k") as k: - # Offset allong indices or updates + # Offset along indices or updates index1 = i * ind_before_axis_stride + k * ind_after_axis_range + j # TODO(vvchernov): assert for out of bounds, separated check for indices k_new = indices[index1] index_check = tir.LT(k_new, tir.const(0, indices.dtype)) k_new += tir.Select(index_check, axis_range, tir.const(0, indices.dtype)) - # Offset allong data + # Offset along data index2 = i * before_axis_stride + k_new * after_axis_range + j if reduction == "update": out[index2] = updates[index1] From c617f22e1a1cf0110dbf07ab62ce6ebbca9761a9 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 8 Feb 2023 13:05:33 +0300 Subject: [PATCH 29/37] fix --- python/tvm/topi/cuda/scatter_elements.py | 5 ++--- python/tvm/topi/scatter_elements.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index a6fa17c1f3fe..f377bbd19323 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -68,11 +68,10 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): if not isinstance(axis, int): axis = get_const_int(axis) - if axis < 0: - axis = len(shape) + axis - # Prepare ranges and strides shape = data.shape + if axis < 0: + axis = len(shape) + axis axis_range = cast(shape[axis], indices.dtype) before_axis_range = 1 diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index 170113a136a0..e7948f6cb6c7 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -66,11 +66,10 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): if not isinstance(axis, int): axis = utils.get_const_int(axis) - if axis < 0: - axis = len(shape) + axis - # Prepare ranges and strides shape = data.shape + if axis < 0: + axis = len(shape) + axis axis_range = cast(shape[axis], indices.dtype) full_range = 1 From 21ce735d21f1fe33703aa77347c35153b3281b8a Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 8 Feb 2023 19:49:33 +0300 Subject: [PATCH 30/37] small update --- python/tvm/topi/cuda/scatter_elements.py | 27 ++++++++++++------------ 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index f377bbd19323..59dc00bcbab3 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -17,8 +17,7 @@ # pylint: disable=invalid-name """Scatter operator """ import tvm -from tvm import te -from tvm import tir +from tvm import te, tir from ..utils import ceil_div, get_const_int from ..math import cast @@ -97,13 +96,13 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): ind_before_axis_stride = ind_axis_range * ind_after_axis_range ind_full_range_excl_axis = ind_before_axis_range * ind_after_axis_range - def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): + def gen_ir(data, indices, updates, out): ib = tir.ir_builder.create() - data = ib.buffer_ptr(data_ptr) - indices = ib.buffer_ptr(indices_ptr) - updates = ib.buffer_ptr(updates_ptr) - out = ib.buffer_ptr(out_ptr) + data_ptr = ib.buffer_ptr(data) + indices_ptr = ib.buffer_ptr(indices) + updates_ptr = ib.buffer_ptr(updates) + out_ptr = ib.buffer_ptr(out) max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) # Copy initial input data to output @@ -116,7 +115,7 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): index = bx * max_threads + tx with ib.if_scope(index < full_range): - out[index] = data[index] + out_ptr[index] = data_ptr[index] # TODO (vvchernov): use atomic function for special conditions (see cuda.scatter_nd) with ib.new_scope(): @@ -134,21 +133,21 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): # Offset along indices or updates index1 = i * ind_before_axis_stride + k * ind_after_axis_range + j # TODO(vvchernov): assert for out of bounds, separated check for indices - k_new = indices[index1] + k_new = indices_ptr[index1] index_check = tir.LT(k_new, tir.const(0, indices.dtype)) k_new += tir.Select(index_check, axis_range, tir.const(0, indices.dtype)) # Offset along data index2 = i * before_axis_stride + k_new * after_axis_range + j if reduction == "update": - out[index2] = updates[index1] + out_ptr[index2] = updates_ptr[index1] elif reduction == "add": - out[index2] += updates[index1] + out_ptr[index2] += updates_ptr[index1] elif reduction == "mul": - out[index2] *= updates[index1] + out_ptr[index2] *= updates_ptr[index1] elif reduction == "min": - tir.min(out[index2], updates[index1]) + out_ptr[index2] = tir.min(out_ptr[index2], updates_ptr[index1]) elif reduction == "max": - tir.max(out[index2], updates[index1]) + out_ptr[index2] = tir.max(out_ptr[index2], updates_ptr[index1]) else: raise NotImplementedError( "scatter_elements reduction not in [update, add, mul, min, max]:", From 1ebe71d8fb1456f0f02c2889f90c23aa8d325dca Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 8 Feb 2023 21:24:17 +0300 Subject: [PATCH 31/37] transfer indices check out of general loop --- python/tvm/topi/cuda/scatter_elements.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index 59dc00bcbab3..327bfe4fbd91 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -94,6 +94,7 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): elif i > axis: ind_after_axis_range *= value ind_before_axis_stride = ind_axis_range * ind_after_axis_range + ind_full_range = ind_before_axis_range * ind_before_axis_stride ind_full_range_excl_axis = ind_before_axis_range * ind_after_axis_range def gen_ir(data, indices, updates, out): @@ -117,6 +118,20 @@ def gen_ir(data, indices, updates, out): with ib.if_scope(index < full_range): out_ptr[index] = data_ptr[index] + # Check indices and shift to positive side if need + with ib.new_scope(): + num_blocks_1 = ceil_div(ind_full_range, max_threads) + bx1 = te.thread_axis("blockIdx.x") + tx1 = te.thread_axis("threadIdx.x") + ib.scope_attr(bx1, "thread_extent", num_blocks_1) + ib.scope_attr(tx1, "thread_extent", max_threads) + + ind_fused = bx2 * max_threads + tx2 + with ib.if_scope(ind_fused < ind_full_range): + index_check = tir.LT(indices_ptr[ind_fused], tir.const(0, indices.dtype)) + indices_ptr[ind_fused] += tir.Select(index_check, axis_range, tir.const(0, indices.dtype)) + # TODO(vvchernov): assert for index out of bounds + # TODO (vvchernov): use atomic function for special conditions (see cuda.scatter_nd) with ib.new_scope(): num_blocks_2 = ceil_div(ind_full_range_excl_axis, max_threads) @@ -132,12 +147,8 @@ def gen_ir(data, indices, updates, out): with ib.for_range(0, ind_axis_range, "k") as k: # Offset along indices or updates index1 = i * ind_before_axis_stride + k * ind_after_axis_range + j - # TODO(vvchernov): assert for out of bounds, separated check for indices - k_new = indices_ptr[index1] - index_check = tir.LT(k_new, tir.const(0, indices.dtype)) - k_new += tir.Select(index_check, axis_range, tir.const(0, indices.dtype)) # Offset along data - index2 = i * before_axis_stride + k_new * after_axis_range + j + index2 = i * before_axis_stride + indices_ptr[index1] * after_axis_range + j if reduction == "update": out_ptr[index2] = updates_ptr[index1] elif reduction == "add": From 5721d94d771b1b60d25d33b616b50e5afb43e151 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 9 Feb 2023 09:17:23 +0300 Subject: [PATCH 32/37] trancsfer ranges and strides calculation to gpu device --- python/tvm/topi/cuda/scatter_elements.py | 60 ++++++++++++------------ 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index 327bfe4fbd91..f0a83555f9d7 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -67,36 +67,6 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): if not isinstance(axis, int): axis = get_const_int(axis) - # Prepare ranges and strides - shape = data.shape - if axis < 0: - axis = len(shape) + axis - axis_range = cast(shape[axis], indices.dtype) - - before_axis_range = 1 - after_axis_range = 1 - for i, value in enumerate(shape, 0): - if i < axis: - before_axis_range *= value - elif i > axis: - after_axis_range *= value - before_axis_stride = axis_range * after_axis_range - full_range = before_axis_range * before_axis_stride - - ind_shape = indices.shape - ind_axis_range = shape[axis] - - ind_before_axis_range = 1 - ind_after_axis_range = 1 - for i, value in enumerate(ind_shape, 0): - if i < axis: - ind_before_axis_range *= value - elif i > axis: - ind_after_axis_range *= value - ind_before_axis_stride = ind_axis_range * ind_after_axis_range - ind_full_range = ind_before_axis_range * ind_before_axis_stride - ind_full_range_excl_axis = ind_before_axis_range * ind_after_axis_range - def gen_ir(data, indices, updates, out): ib = tir.ir_builder.create() @@ -105,6 +75,36 @@ def gen_ir(data, indices, updates, out): updates_ptr = ib.buffer_ptr(updates) out_ptr = ib.buffer_ptr(out) + # Prepare ranges and strides + shape = data.shape + if axis < 0: + axis = len(shape) + axis + axis_range = cast(shape[axis], indices.dtype) + + before_axis_range = 1 + after_axis_range = 1 + for i, value in enumerate(shape, 0): + if i < axis: + before_axis_range *= value + elif i > axis: + after_axis_range *= value + before_axis_stride = axis_range * after_axis_range + full_range = before_axis_range * before_axis_stride + + ind_shape = indices.shape + ind_axis_range = shape[axis] + + ind_before_axis_range = 1 + ind_after_axis_range = 1 + for i, value in enumerate(ind_shape, 0): + if i < axis: + ind_before_axis_range *= value + elif i > axis: + ind_after_axis_range *= value + ind_before_axis_stride = ind_axis_range * ind_after_axis_range + ind_full_range = ind_before_axis_range * ind_before_axis_stride + ind_full_range_excl_axis = ind_before_axis_range * ind_after_axis_range + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) # Copy initial input data to output with ib.new_scope(): From 2bbff6ad187d1dab18108441a28d2c00c6002ef6 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 9 Feb 2023 09:47:57 +0300 Subject: [PATCH 33/37] fixes --- python/tvm/topi/cuda/scatter_elements.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index f0a83555f9d7..8bacea5e29f7 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -126,10 +126,12 @@ def gen_ir(data, indices, updates, out): ib.scope_attr(bx1, "thread_extent", num_blocks_1) ib.scope_attr(tx1, "thread_extent", max_threads) - ind_fused = bx2 * max_threads + tx2 + ind_fused = bx1 * max_threads + tx1 with ib.if_scope(ind_fused < ind_full_range): index_check = tir.LT(indices_ptr[ind_fused], tir.const(0, indices.dtype)) - indices_ptr[ind_fused] += tir.Select(index_check, axis_range, tir.const(0, indices.dtype)) + indices_ptr[ind_fused] += tir.Select( + index_check, axis_range, tir.const(0, indices.dtype) + ) # TODO(vvchernov): assert for index out of bounds # TODO (vvchernov): use atomic function for special conditions (see cuda.scatter_nd) From fe27ea8994830c0cffe168561b8a21d397280acc Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 9 Feb 2023 13:59:51 +0300 Subject: [PATCH 34/37] fix axis --- python/tvm/topi/cuda/scatter_elements.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index 8bacea5e29f7..8a27fa56fe7b 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -67,7 +67,7 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): if not isinstance(axis, int): axis = get_const_int(axis) - def gen_ir(data, indices, updates, out): + def gen_ir(data, indices, updates, out, axis): ib = tir.ir_builder.create() data_ptr = ib.buffer_ptr(data) @@ -92,7 +92,7 @@ def gen_ir(data, indices, updates, out): full_range = before_axis_range * before_axis_stride ind_shape = indices.shape - ind_axis_range = shape[axis] + ind_axis_range = ind_shape[axis] ind_before_axis_range = 1 ind_after_axis_range = 1 @@ -173,7 +173,7 @@ def gen_ir(data, indices, updates, out): return te.extern( [data.shape], [data, indices, updates], - lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], axis), dtype=data.dtype, out_buffers=[out_buf], name="scatter_elements_cuda", From 62c56b7baf07319a9ba8580011ca82b590c2e55f Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 10 Feb 2023 08:16:09 +0300 Subject: [PATCH 35/37] clean code --- python/tvm/relay/frontend/onnx.py | 2 +- tests/python/frontend/onnx/test_forward.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 1d86127249d8..2a1890627225 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2886,7 +2886,7 @@ def _args_check(cls, inputs, attr, red_valids=None): def _impl_v11(cls, inputs, attr, params): axis = cls._args_check(inputs, attr) - return _op.scatter(inputs[0], inputs[1], inputs[2], axis) + return _op.scatter_elements(inputs[0], inputs[1], inputs[2], axis, "update") @classmethod def _impl_v16(cls, inputs, attr, params): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 1b01e0a24b8d..293f4d38e649 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -922,7 +922,7 @@ def verify_scatter_elements(in_shape, indices, axis=0, reduction="update"): 0, "mul", ) - # TODO(vvchernov): min and mul options are supported from 18 version, but CI supports 17 only + # TODO(vvchernov): min and max options are supported from 18 version, but CI supports 17 only # # Scatter elements with min reduction of duplicates # verify_scatter_elements( # (3, 3, 3), @@ -5468,7 +5468,6 @@ def verify_eyelike(indata, dynamic=False): "test_reduce_sum_negative_axes_keepdims_example", "test_reduce_sum_negative_axes_keepdims_random", "test_roialign_aligned_true", - "test_scatter_elements_with_duplicate_indices", "test_sequence_insert_at_back", "test_sequence_insert_at_front", "test_sequence_map_add_1_sequence_1_tensor", From 1fa653fb8c31cd889e5dc3bd2374164c2ddcb7d1 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 16 Feb 2023 12:21:12 +0300 Subject: [PATCH 36/37] fix after review --- python/tvm/topi/cuda/scatter_elements.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index 8a27fa56fe7b..3e6eebf81915 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -118,22 +118,6 @@ def gen_ir(data, indices, updates, out, axis): with ib.if_scope(index < full_range): out_ptr[index] = data_ptr[index] - # Check indices and shift to positive side if need - with ib.new_scope(): - num_blocks_1 = ceil_div(ind_full_range, max_threads) - bx1 = te.thread_axis("blockIdx.x") - tx1 = te.thread_axis("threadIdx.x") - ib.scope_attr(bx1, "thread_extent", num_blocks_1) - ib.scope_attr(tx1, "thread_extent", max_threads) - - ind_fused = bx1 * max_threads + tx1 - with ib.if_scope(ind_fused < ind_full_range): - index_check = tir.LT(indices_ptr[ind_fused], tir.const(0, indices.dtype)) - indices_ptr[ind_fused] += tir.Select( - index_check, axis_range, tir.const(0, indices.dtype) - ) - # TODO(vvchernov): assert for index out of bounds - # TODO (vvchernov): use atomic function for special conditions (see cuda.scatter_nd) with ib.new_scope(): num_blocks_2 = ceil_div(ind_full_range_excl_axis, max_threads) @@ -149,8 +133,11 @@ def gen_ir(data, indices, updates, out, axis): with ib.for_range(0, ind_axis_range, "k") as k: # Offset along indices or updates index1 = i * ind_before_axis_stride + k * ind_after_axis_range + j + # Get index and shift to positive side if need + new_index = indices_ptr[index1] + shifted_index = new_index + (new_index < 0) * axis_range # Offset along data - index2 = i * before_axis_stride + indices_ptr[index1] * after_axis_range + j + index2 = i * before_axis_stride + shifted_index * after_axis_range + j if reduction == "update": out_ptr[index2] = updates_ptr[index1] elif reduction == "add": From ac7c2301a2a1f5df294593cf1895b92adda64d39 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 16 Feb 2023 12:38:17 +0300 Subject: [PATCH 37/37] fix lint --- python/tvm/topi/cuda/scatter_elements.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index 3e6eebf81915..8ed3e2972081 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -102,7 +102,6 @@ def gen_ir(data, indices, updates, out, axis): elif i > axis: ind_after_axis_range *= value ind_before_axis_stride = ind_axis_range * ind_after_axis_range - ind_full_range = ind_before_axis_range * ind_before_axis_stride ind_full_range_excl_axis = ind_before_axis_range * ind_after_axis_range max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)