diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index ef4265d73b4b..e53ba3c36e7f 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -164,6 +164,18 @@ struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
         "either \"update\", \"add\", \"mul\", \"mean\", \"min\" or \"max\".");
   }
 };  // struct ScatterElementsAttrs
+
+/*! \brief Attributes used in scatter_nd operators */
+struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
+  String reduction;
+
+  TVM_DECLARE_ATTRS(ScatterNDAttrs, "relax.attrs.ScatterNDAttrs") {
+    TVM_ATTR_FIELD(reduction).set_default("update").describe(
+        "Accumulation mode of the ScatterND, "
+        "either \"update\", \"add\", \"mul\", \"min\" or \"max\".");
+  }
+};  // struct ScatterNDAttrs
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index b9eb141bd14e..f1fa67546c2a 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -692,6 +692,36 @@ def _impl_v11(cls, bb, inputs, attr, params):
         return relax.op.scatter_elements(inputs[0], inputs[1], inputs[2], axis=axis)
 
 
+class ScatterND(OnnxOpConverter):
+    """Convert an onnx ScatterND node into an equivalent Relax expression."""
+
+    @staticmethod
+    def _reduction_check(attr, valid_reductions: List[str]):
+        reduction = attr.get("reduction", None)
+        reduction = reduction or b"update"
+        reduction = reduction.decode("utf-8")
+        reduction = "update" if reduction == "none" else reduction
+        assert (
+            reduction in valid_reductions
+        ), f"Only {valid_reductions} reductions are supported, but got {reduction}"
+
+        return reduction
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2])
+
+    @classmethod
+    def _impl_v16(cls, bb, inputs, attr, params):
+        reduction = cls._reduction_check(attr, ["update", "add", "mul"])
+        return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction)
+
+    @classmethod
+    def _impl_v18(cls, bb, inputs, attr, params):
+        reduction = cls._reduction_check(attr, ["update", "add", "mul", "min", "max"])
+        return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction)
+
+
 class Size(OnnxOpConverter):
     """Convert an onnx Size node into an equivalent Relax expression."""
 
@@ -2827,7 +2857,7 @@ def _get_convert_map():
         # "GatherND": GatherND,
         "Scatter": Scatter,
         "ScatterElements": ScatterElements,
-        # "ScatterND": ScatterND,
+        "ScatterND": ScatterND,
         # "Compress": Compress,
         "Size": Size,
         # "EyeLike": EyeLike,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index efd9997698ee..84b31ccec01e 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -93,6 +93,7 @@
     repeat,
     reshape,
     scatter_elements,
+    scatter_nd,
     split,
     squeeze,
     tile,
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index da0a09cc7b51..1673a79b08c2 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -511,3 +511,42 @@ def scatter_elements(
     """
     return _ffi_api.scatter_elements(data, indices, updates, axis, reduction)  # type: ignore
+
+
+def scatter_nd(data: Expr, indices: Expr, updates: Expr, reduction: str = "update") -> Expr:
+    """Scatter updates into an array according to indices.
+
+    Parameters
+    ----------
+    data: relax.Expr
+        The input data to be updated.
+
+    indices: relax.Expr
+        The index positions to update in `data`.
+
+    updates: relax.Expr
+        The values to scatter at the index positions.
+
+    reduction: str
+        The reduction mode to apply: "update", "add", "mul", "max" or "min".
+        It is "update" by default.
+
+    Returns
+    -------
+    result : relax.Expr
+        The result has the same shape as data.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        # inputs
+        data = [1, 2, 3, 4, 5, 6, 7, 8]
+        indices = [[4], [3], [1], [7]]
+        updates = [9, 10, 11, 12]
+
+        # output
+        output = [1, 11, 3, 10, 9, 6, 7, 12]
+
+    """
+    return _ffi_api.scatter_nd(data, indices, updates, reduction)  # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 1efa78c069ad..105d763403af 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -168,6 +168,23 @@ def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.scatter_nd")
+def _scatter_nd(bb: BlockBuilder, call: Call) -> Expr:
+    # TODO(relax-team): Support native scatter_nd without te extern
+    def scatter_nd(data, indices, updates, reduction):
+        # topi.scatter_nd expects the index depth on the leading axis,
+        # so move the last axis of `indices` to the front.
+        axes = list(range(len(indices.shape)))
+        indices = topi.transpose(indices, axes[-1:] + axes[:-1])
+        return topi.scatter_nd(data, indices, updates, reduction)
+
+    return bb.call_te(
+        scatter_nd,
+        call.args[0],
+        call.args[1],
+        call.args[2],
+        call.attrs.reduction,
+    )
+
+
 @register_legalize("relax.layout_transform")
 def _layout_transform(bb: BlockBuilder, call: Call) -> Expr:
     def te_layout_transform(data, name):
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index e6ff35ebe56b..f7847e2af8ed 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -138,6 +138,7 @@
     round,
     rsqrt,
     scatter_elements,
+    scatter_nd,
     shape_of,
     shape_to_tensor,
     sigmoid,
@@ -738,6 +739,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "cumsum",
     "einsum",
     "scatter_elements",
+    "scatter_nd",
     "dataflow",
     "device",
     "divide",
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 2b1c6eafb652..ca7d0a0945bc 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -1531,5 +1531,139 @@ TVM_REGISTER_OP("relax.scatter_elements")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterElements)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.scatter_nd */
+TVM_REGISTER_NODE_TYPE(ScatterNDAttrs);
+
+Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction) {
+  auto attrs = make_object<ScatterNDAttrs>();
+  attrs->reduction = std::move(reduction);
+  static const Op& op = Op::Get("relax.scatter_nd");
+  return Call(op, {data, indices, updates}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.scatter_nd").set_body_typed(scatter_nd);
+
+StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) {
+  // `call->args` contains: [data, indices, updates]
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  ICHECK_EQ(call->args.size(), 3);
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* indices_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  const auto* updates_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+
+  if (data_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "ScatterND op requires the input data to be a tensor. However, the given type is "
+        << call->args[0]->GetTypeKey());
+  }
+  if (indices_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "ScatterND op requires the input indices to be a tensor. However, the given type is "
+        << call->args[1]->GetTypeKey());
+  }
+  if (updates_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "ScatterND op requires the input updates to be a tensor. However, the given type is "
+        << call->args[2]->GetTypeKey());
+  }
+
+  if (data_sinfo->IsUnknownDtype() || updates_sinfo->IsUnknownDtype()) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ScatterND op requires the input data and updates to have known dtype. "
+                        "However, the given types are "
+                     << "data: " << data_sinfo->dtype << ", updates: " << updates_sinfo->dtype);
+  }
+
+  if (data_sinfo->dtype != updates_sinfo->dtype) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ScatterND op requires the input data to have the same dtype as updates. "
+                        "However, the given types are "
+                     << "data: " << data_sinfo->dtype << ", updates: " << updates_sinfo->dtype);
+  }
+
+  if (indices_sinfo->IsUnknownDtype()) {
+    LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type.";
+  } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ScatterND op requires the input indices to have integer dtype. However, "
+                        "the given indices dtype is "
+                     << indices_sinfo->dtype);
+  }
+
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
+  const auto* updates_shape = updates_sinfo->shape.as<ShapeExprNode>();
+
+  if (data_shape && indices_shape && updates_shape) {
+    const IntImmNode* k_dim = indices_shape->values[indices_sinfo->ndim - 1].as<IntImmNode>();
+    if (!k_dim) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "ScatterND needs a static shape for the last axis of indices, got "
+                       << indices_shape->values);
+    }
+    const size_t data_ndim = data_sinfo->ndim;
+    const size_t indices_ndim = indices_sinfo->ndim;
+    const size_t updates_ndim = updates_sinfo->ndim;
+    if (data_ndim + indices_ndim - k_dim->value - 1 != updates_ndim) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "ScatterND op requires the rank of updates to equal "
+                          "`rank(data) + rank(indices) - indices.shape[-1] - 1`. "
+                          "However, the given shapes are "
+                       << "data: " << ShapeExpr(data_shape->values)
+                       << ", indices: " << ShapeExpr(indices_shape->values)
+                       << ", updates: " << ShapeExpr(updates_shape->values));
+    }
+    if (k_dim->value > static_cast<int64_t>(data_ndim)) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "ScatterND op requires the last axis of indices tensor to be less than "
+                          "or equal to the rank of data tensor. However, the given shapes are "
+                       << "data: " << ShapeExpr(data_shape->values)
+                       << ", indices: " << ShapeExpr(indices_shape->values));
+    }
+    Array<PrimExpr> expected_updates_shape;
+    for (size_t i = 0; i < indices_ndim - 1; i++) {
+      expected_updates_shape.push_back(indices_shape->values[i]);
+    }
+    for (size_t i = k_dim->value; i < data_ndim; i++) {
+      expected_updates_shape.push_back(data_shape->values[i]);
+    }
+    auto check_shape = [&](const Array<PrimExpr>& expected, const Array<PrimExpr>& actual) {
+      if (expected.size() != actual.size()) {
+        return false;
+      }
+      for (size_t i = 0; i < expected.size(); i++) {
+        if (!analyzer->CanProve(expected[i] == actual[i])) {
+          return false;
+        }
+      }
+      return true;
+    };
+    if (!check_shape(expected_updates_shape, updates_shape->values)) {
+      ctx->ReportFatal(
+          Diagnostic::Error(call)
+          << "ScatterND op requires the updates tensor to have the shape with constraint: "
+          << "`updates.shape = indices.shape[:-1] + data.shape[K:]`, but got "
+          << "updates.shape: " << ShapeExpr(updates_shape->values) << ", indices.shape: "
+          << ShapeExpr(indices_shape->values) << ", data.shape: " << ShapeExpr(data_shape->values));
+    }
+  }
+  if (data_shape) {
+    return TensorStructInfo(ShapeExpr(data_shape->values), data_sinfo->dtype, data_sinfo->vdevice);
+  }
+  return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.scatter_nd")
+    .set_attrs_type<ScatterNDAttrs>()
+    .set_num_inputs(3)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("indices", "Tensor", "The indices tensor.")
+    .add_argument("updates", "Tensor", "The input tensor of updates.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterND)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 68622f1359e0..e9fa1131e803 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -173,6 +173,39 @@ Expr tile(Expr data, Array<Integer> repeats);
  */
 Expr flip(Expr data, Integer axis);
 
+/*!
+ * \brief Scatter updates into an array according to indices.
+ * \param data The input tensor.
+ * \param indices The index positions to update in `data`.
+ * \param updates The values to scatter at the index positions.
+ * \param axis The axis along which to scatter the elements.
+ * \param reduction The reduction mode of the scatter elements,
+ * either "update", "add", "mul", "mean", "max" or "min".
+ * \return The computed result.
+ */
+Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String reduction);
+
+/*!
+ * \brief Scatter updates into an array according to indices.
+ * \param data The input tensor to be updated.
+ * \param indices The index positions to update in `data`.
+ * \param updates The values to scatter at the index positions.
+ * \param reduction The reduction mode of the scatter operation.
+ *        Supported modes are:
+ *        - "update": Replace the values at the indices with the update values.
+ *        - "add": Add the update values to the existing values at the indices.
+ *        - "mul": Multiply the existing values at the indices by the update values.
+ *        - "max": Take the maximum of the existing value and the update value at each index.
+ *        - "min": Take the minimum of the existing value and the update value at each index.
+ * \return The computed result tensor with the same shape as `data`.
+ *
+ * \note The shape of `indices` defines the shape of the scattered tensor.
+ * The last dimension of `indices` corresponds to the depth of each index vector.
+ * The leading dimensions of `updates` must match `indices.shape[:-1]`, and its
+ * remaining dimensions must match the shape of the slice selected by each index,
+ * i.e. `updates.shape = indices.shape[:-1] + data.shape[K:]` with `K = indices.shape[-1]`.
+ */
+Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction);
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 57f94c8442f7..9ac520c58e14 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -118,7 +118,6 @@ def check_correctness(
     tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
     # Legalize any relax ops into tensorir.
     tvm_model = relax.transform.LegalizeOps()(tvm_model)
-    print(tvm_model)
 
     # Separate model from parameters.
     tvm_model, params = relax.frontend.detach_params(tvm_model)
@@ -523,6 +522,38 @@ def test_scatter(axis: int, name: str, opset: int):
     check_correctness(model, inputs={"indices": indices}, opset=opset)
 
 
+@pytest.mark.parametrize("reduction", ["none", "add", "mul"])
+def test_scatter_nd(reduction):
+    def verify_scatter_nd(data_shape, indices_shape, updates_shape):
+        scatter_nd_node = helper.make_node(
+            "ScatterND",
+            ["data", "indices", "updates"],
+            ["output"],
+            reduction=reduction,
+        )
+
+        graph = helper.make_graph(
+            [scatter_nd_node],
+            "scatter_nd_test",
+            inputs=[
+                helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape),
+                helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape),
+                helper.make_tensor_value_info("updates", TensorProto.FLOAT, updates_shape),
+            ],
+            outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, data_shape)],
+        )
+
+        model = helper.make_model(graph, producer_name="scatter_nd_test")
+
+        indices = np.random.choice(data_shape[0], indices_shape)
+        check_correctness(model, inputs={"indices": indices}, opset=16)
+
+    verify_scatter_nd([8], [4, 1], [4])
+    verify_scatter_nd([4, 4, 4], [2, 1], [2, 4, 4])
+    verify_scatter_nd([4, 5, 6], [2, 3, 2], [2, 3, 6])
+    verify_scatter_nd([10], [5, 1], [5])
+
+
 def test_size():
     test_node = helper.make_node("Size", ["x"], ["y"])
     graph = helper.make_graph(
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index ddb92725d438..e958b03e4ce6 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -45,6 +45,7 @@ def test_op_correctness():
     assert relax.op.einsum(x, subscripts="ii").op == Op.get("relax.einsum")
     assert relax.op.flip(x, axis=1).op == Op.get("relax.flip")
     assert relax.op.scatter_elements(x, x, x).op == Op.get("relax.scatter_elements")
+    assert relax.op.scatter_nd(x, x, x).op == Op.get("relax.scatter_nd")
 
 
 def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
@@ -3352,5 +3353,29 @@ def test_scatter_elements_infer_struct_info_rank_shape_mismatch():
         bb.normalize(relax.op.scatter_elements(d0, i0, u4))
 
 
+def test_scatter_nd_infer_struct_info():
+    bb = relax.BlockBuilder()
+
+    d0 = relax.Var("data", R.Tensor((8,), "float32"))
+    i0 = relax.Var("indices", R.Tensor((4, 1), "int64"))
+    u0 = relax.Var("updates", R.Tensor((4,), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.scatter_nd(d0, i0, u0, "update"),
+        relax.TensorStructInfo((8,), dtype="float32"),
+    )
+
+    d1 = relax.Var("data", R.Tensor((4, 4, 4), "float32"))
+    i1 = relax.Var("indices", R.Tensor((2, 1), "int64"))
+    u1 = relax.Var("updates", R.Tensor((2, 4, 4), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.scatter_nd(d1, i1, u1, "update"),
+        relax.TensorStructInfo((4, 4, 4), dtype="float32"),
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index a0ecd3c73dc9..0565b7a5790a 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import pytest
 import tvm
 from tvm import relax
 from tvm.relax.transform import LegalizeOps
@@ -1739,5 +1738,66 @@ def te_layout_transform(
     tvm.ir.assert_structural_equal(Expected, After)
 
 
+def test_scatter_nd():
+
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            data: R.Tensor((8,), "float32"),
+            indices: R.Tensor((4, 1), "int64"),
+            updates: R.Tensor((4,), "float32"),
+        ) -> R.Tensor((8,), "float32"):
+            gv: R.Tensor((8,), "float32") = R.scatter_nd(data, indices, updates, reduction="update")
+            return gv
+
+    After = relax.transform.LegalizeOps()(Before)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            data: R.Tensor((8,), "float32"),
+            indices: R.Tensor((4, 1), "int64"),
+            updates: R.Tensor((4,), "float32"),
+        ) -> R.Tensor((8,), "float32"):
+            gv = R.call_tir(
+                Expected.scatter_nd, (data, indices, updates), R.Tensor((8,), dtype="float32")
+            )
+            return gv
+
+        @T.prim_func(private=True)
+        def scatter_nd(var_data: T.handle, var_indices: T.handle, var_updates: T.handle, var_scatter_nd_generic: T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            data = T.match_buffer(var_data, (T.int64(8),), offset_factor=1)
+            indices = T.match_buffer(var_indices, (T.int64(4), T.int64(1)), "int64")
+            updates = T.match_buffer(var_updates, (T.int64(4),), offset_factor=1)
+            out_buf = T.match_buffer(var_scatter_nd_generic, (T.int64(8),))
+            with T.block("root"):
+                T.reads()
+                T.writes()
+                T_transpose = T.alloc_buffer((T.int64(1), T.int64(4)), "int64")
+                for ax0 in range(T.int64(1)):
+                    for ax1 in range(T.int64(4)):
+                        with T.block("T_transpose"):
+                            v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                            v_ax1 = T.axis.spatial(T.int64(4), ax1)
+                            T.reads(indices[v_ax1, v_ax0])
+                            T.writes(T_transpose[v_ax0, v_ax1])
+                            T_transpose[v_ax0, v_ax1] = indices[v_ax1, v_ax0]
+                with T.block("scatter_nd_generic"):
+                    T.reads()
+                    T.writes()
+                    for i in range(T.int64(8)):
+                        out_buf[i] = data[i]
+                    for j in range(T.int64(4)):
+                        for k in T.parallel(T.int64(1)):
+                            out_buf[k + T_transpose[j // T.int64(4), j % T.int64(4)]] = updates[j + k]
+
+    # fmt: on
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
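
As a sanity check on the semantics above (both the docstring example and the
`updates.shape = indices.shape[:-1] + data.shape[K:]` constraint enforced by
InferStructInfoScatterND), here is a minimal NumPy reference sketch. The helper
name `reference_scatter_nd` is illustrative only and is not part of this patch.

import numpy as np

def reference_scatter_nd(data, indices, updates, reduction="update"):
    """ONNX-style ScatterND: the last axis of `indices` holds an index vector
    of depth K into `data`; each one selects a slice of shape data.shape[K:]
    that is combined with the matching slice of `updates`."""
    k = indices.shape[-1]
    assert updates.shape == indices.shape[:-1] + data.shape[k:]
    out = data.copy()
    for pos in np.ndindex(indices.shape[:-1]):
        loc = tuple(indices[pos])  # index vector of depth k into data
        if reduction == "update":
            out[loc] = updates[pos]
        elif reduction == "add":
            out[loc] += updates[pos]
        elif reduction == "mul":
            out[loc] *= updates[pos]
        elif reduction == "min":
            out[loc] = np.minimum(out[loc], updates[pos])
        elif reduction == "max":
            out[loc] = np.maximum(out[loc], updates[pos])
    return out

# Reproduces the scatter_nd docstring example:
# reference_scatter_nd(np.arange(1, 9),
#                      np.array([[4], [3], [1], [7]]),
#                      np.array([9, 10, 11, 12]))
# -> array([ 1, 11,  3, 10,  9,  6,  7, 12])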