diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index b0c8108b16247..eb73427f184cd 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -101,6 +101,14 @@ struct ScatterAttrs : public tvm::AttrsNode { } }; +struct ScatterAddAttrs : public tvm::AttrsNode { + Integer axis; + + TVM_DECLARE_ATTRS(ScatterAddAttrs, "relay.attrs.ScatterAddAttrs") { + TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values."); + } +}; + struct GatherAttrs : public tvm::AttrsNode { Integer axis; diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index dc1265870475d..a2c374d477981 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -101,6 +101,14 @@ def compute_scatter(attrs, inputs, output_type): _reg.register_schedule("scatter", strategy.schedule_scatter) +# scatter_add +@_reg.register_compute("scatter_add") +def compute_scatter_add(attrs, inputs, output_type): + """Compute definition of scatter_add""" + return [topi.scatter_add(inputs[0], inputs[1], inputs[2], attrs.axis)] + +_reg.register_schedule("scatter_add", strategy.schedule_scatter_add) + ##################### # Shape functions # ##################### @@ -396,6 +404,7 @@ def argwhere_shape_func(attrs, inputs, out_ndims): return ValueError("Does not support rank higher than 5 in argwhere") _reg.register_shape_func("scatter", False, elemwise_shape_func) +_reg.register_shape_func("scatter_add", False, elemwise_shape_func) @script def _layout_transform_shape_func(data_shape, diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index db0577cf8bdf6..62c2948b51e2c 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -842,6 +842,13 @@ def schedule_scatter(attrs, outs, target): with target: return topi.generic.schedule_scatter(outs) +# scatter_add +@generic_func +def schedule_scatter_add(attrs, outs, target): + """schedule scatter_add""" + with target: + return topi.generic.schedule_scatter_add(outs) + # bitserial_conv2d def wrap_compute_bitserial_conv2d(topi_compute): """wrap bitserial_conv2d topi compute""" diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index ae10dd50a87f7..6f23af2c0cee0 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -275,6 +275,30 @@ def scatter(data, indices, updates, axis): """ return _make.scatter(data, indices, updates, axis) +def scatter_add(data, indices, updates, axis): + """Update data by adding values in updates at positions defined by indices + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + indices : relay.Expr + The index locations to update. + + updates : relay.Expr + The values to add. + + axis : int + The axis to scatter_add on + + Returns + ------- + ret : relay.Expr + The computed result. + """ + return _make.scatter_add(data, indices, updates, axis) + def reshape_like(data, shape_like): """Reshapes the input array by the size of another array. For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index cc1150cb9bae4..1b072532d8b70 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -806,6 +806,55 @@ RELAY_REGISTER_OP("scatter") .set_attr("TOpPattern", kOpaque) .set_support_level(10); +// Scatter_add +TVM_REGISTER_NODE_TYPE(ScatterAddAttrs); + +// Scatter Add +bool ScatterAddRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(num_inputs, 3); + CHECK_EQ(types.size(), 4); + auto data = types[0].as(); + if (data == nullptr) { + return false; + } + auto indices = types[1].as(); + if (indices == nullptr) { + return false; + } + auto updates = types[2].as(); + if (updates == nullptr) { + return false; + } + CHECK(indices->dtype.is_int()) << "indices of scatter_add must be tensor of integer"; + const auto param = attrs.as(); + CHECK(param != nullptr); + reporter->Assign(types[3], TensorType(data->shape, data->dtype)); + return true; +} + +TVM_REGISTER_GLOBAL("relay.op._make.scatter_add") + .set_body_typed([](Expr data, Expr indices, Expr updates, int axis) { + auto attrs = make_object(); + attrs->axis = std::move(axis); + static const Op& op = Op::Get("scatter_add"); + return Call(op, {data, indices, updates}, Attrs(attrs), {}); + }); + +RELAY_REGISTER_OP("scatter_add") + .describe( + R"doc(Update data by adding values in updates at positions defined by indices)doc" TVM_ADD_FILELINE) + .set_num_inputs(3) + .add_argument("data", "Tensor", "The input data tensor.") + .add_argument("indicies", "Tensor", "The indicies location tensor.") + .add_argument("updates", "Tensor", "The values to update the input with.") + .add_type_rel("ScatterAdd", ScatterAddRel) + .set_attr("TOpIsStateful", false) + .set_attr("TOpPattern", kOpaque) + .set_support_level(10); + +//// + // Take TVM_REGISTER_NODE_TYPE(TakeAttrs); diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 115900fea0f3f..0445c98342b8f 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -811,6 +811,51 @@ def verify_scatter(dshape, ishape, axis=0): verify_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3) +def test_scatter_add(): + + def ref_scatter_add(data, indices, updates, axis=0): + output = np.copy(data) + for index in np.ndindex(*indices.shape): + new_index = list(index) + new_index[axis] = indices[index] + output[tuple(new_index)] += updates[index] + return output + + def verify_scatter_add(dshape, ishape, axis=0): + d = relay.var("d", relay.TensorType(dshape, "float32")) + i = relay.var("i", relay.TensorType(ishape, "int64")) + u = relay.var("u", relay.TensorType(ishape, "float32")) + z = relay.op.scatter_add(d, i, u, axis) + + func = relay.Function([d, i, u], z) + + data_np = np.random.uniform(size=dshape).astype("float32") + updates_np = np.random.uniform(size=ishape).astype("float32") + indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64") + + ref_res = ref_scatter_add(data_np, indices_np, updates_np, axis) + # TODO(mbrookhart): expand testing when adding more backend schedules + for target, ctx in [("llvm", tvm.cpu())]: + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(data_np, indices_np, updates_np) + tvm.testing.assert_allclose( + op_res.asnumpy(), ref_res, rtol=1e-5) + + verify_scatter_add((10, ), (10, ), 0) + verify_scatter_add((10, 5), (10, 5), -2) + verify_scatter_add((10, 5), (10, 5), -1) + verify_scatter_add((10, 5), (3, 5), 0) + verify_scatter_add((12, 4), (7, 2), 1) + verify_scatter_add((2, 3, 4), (1, 3, 4), 0) + verify_scatter_add((2, 3, 4), (2, 1, 4), 1) + verify_scatter_add((2, 3, 4), (2, 3, 1), 2) + verify_scatter_add((2, 3, 4, 5), (1, 3, 4, 5), 0) + verify_scatter_add((6, 3, 4, 5), (2, 3, 4, 5), 1) + verify_scatter_add((2, 3, 8, 5), (2, 3, 1, 1), 2) + verify_scatter_add((16, 16, 4, 5), (16, 16, 4, 5), 3) + + def test_gather(): def verify_gather(data, axis, indices, ref_res): data = np.asarray(data, dtype='float32') diff --git a/topi/python/topi/__init__.py b/topi/python/topi/__init__.py index 56c3a740b8434..f308aa634ec3c 100644 --- a/topi/python/topi/__init__.py +++ b/topi/python/topi/__init__.py @@ -40,6 +40,7 @@ from .broadcast import * from .sort import * from .scatter import * +from .scatter_add import * from .argwhere import * from . import generic from . import nn diff --git a/topi/python/topi/generic/search.py b/topi/python/topi/generic/search.py index 895dadbd130c2..b3c8772046fd9 100644 --- a/topi/python/topi/generic/search.py +++ b/topi/python/topi/generic/search.py @@ -50,3 +50,19 @@ def schedule_scatter(outs): The computation schedule for the op. """ return _default_schedule(outs, False) + + +def schedule_scatter_add(outs): + """Schedule for scatter_add operator. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of scatter_add. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) diff --git a/topi/python/topi/scatter_add.py b/topi/python/topi/scatter_add.py new file mode 100644 index 0000000000000..046972b0f14e3 --- /dev/null +++ b/topi/python/topi/scatter_add.py @@ -0,0 +1,165 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks +"""Scatter Add operator""" +from tvm.te import hybrid + + +@hybrid.script +def _scatter_add_1d(data, indices, updates): + out = output_tensor(data.shape, data.dtype) + for i in range(data.shape[0]): + out[i] = data[i] + for i in range(indices.shape[0]): + out[indices[i] if indices[i] >= 0 else indices[i] + + data.shape[0]] += updates[i] + return out + + +@hybrid.script +def _scatter_add_2d(data, indices, updates, axis): + out = output_tensor(data.shape, data.dtype) + for i in const_range(data.shape[0]): + for j in const_range(data.shape[1]): + out[i, j] = data[i, j] + if axis == 0: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + out[indices[i, j] if indices[i, j] >= + 0 else indices[i, j] + data.shape[axis], j] += updates[i, j] + else: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + out[i, indices[i, j] if indices[i, j] >= + 0 else indices[i, j] + data.shape[axis]] += updates[i, j] + + return out + + +@hybrid.script +def _scatter_add_3d(data, indices, updates, axis): + out = output_tensor(data.shape, data.dtype) + for i in const_range(data.shape[0]): + for j in const_range(data.shape[1]): + for k in const_range(data.shape[2]): + out[i, j, k] = data[i, j, k] + if axis == 0: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + for k in const_range(indices.shape[2]): + out[indices[i, j, k] if indices[i, j, k] >= + 0 else indices[i, j, k] + data.shape[axis], j, k] += updates[i, j, k] + elif axis == 1: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + for k in const_range(indices.shape[2]): + out[i, indices[i, j, k] if indices[i, j, k] >= + 0 else indices[i, j, k] + data.shape[axis], k] += updates[i, j, k] + else: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + for k in const_range(indices.shape[2]): + out[i, j, indices[i, j, k] if indices[i, j, k] >= + 0 else indices[i, j, k] + data.shape[axis]] += updates[i, j, k] + + return out + + +@hybrid.script +def _scatter_add_4d(data, indices, updates, axis): + out = output_tensor(data.shape, data.dtype) + for i in const_range(data.shape[0]): + for j in const_range(data.shape[1]): + for k in const_range(data.shape[2]): + for l in const_range(data.shape[3]): + out[i, j, k, l] = data[i, j, k, l] + + if axis == 0: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + for k in const_range(indices.shape[2]): + for l in const_range(indices.shape[3]): + out[indices[i, j, k, l] if indices[i, j, k, l] >= + 0 else indices[i, j, k, l] + data.shape[axis], + j, k, l] += updates[i, j, k, l] + elif axis == 1: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + for k in const_range(indices.shape[2]): + for l in const_range(indices.shape[3]): + out[i, + indices[i, j, k, l] if indices[i, j, k, l] >= + 0 else indices[i, j, k, l] + data.shape[axis], + k, l] += updates[i, j, k, l] + elif axis == 2: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + for k in const_range(indices.shape[2]): + for l in const_range(indices.shape[3]): + out[i, j, + indices[i, j, k, l] if indices[i, j, k, l] >= + 0 else indices[i, j, k, l] + data.shape[axis], + l] += updates[i, j, k, l] + else: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + for k in const_range(indices.shape[2]): + for l in const_range(indices.shape[3]): + out[i, j, k, + indices[i, j, k, l] if indices[i, j, k, l] >= + 0 else indices[i, j, k, l] + data.shape[axis] + ] += updates[i, j, k, l] + + return out + + +def scatter_add(data, indices, updates, axis=0): + """Update data by adding values in updates at positions defined by indices + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + indices : relay.Expr + The index locations to update. + + updates : relay.Expr + The values to update. + + axis : int + The axis to scatter_add on + + Returns + ------- + ret : relay.Expr + The computed result. + """ + if axis < 0: + axis += len(data.shape) + assert axis >= 0 + assert axis < len(data.shape) + + if len(data.shape) == 1: + return _scatter_add_1d(data, indices, updates) + if len(data.shape) == 2: + return _scatter_add_2d(data, indices, updates, axis) + if len(data.shape) == 3: + return _scatter_add_3d(data, indices, updates, axis) + if len(data.shape) == 4: + return _scatter_add_4d(data, indices, updates, axis) + raise ValueError("scatter_add only support for 1-4 dimensions")