Skip to content

Commit

Permalink
[Relax] Add scatter_nd op support (#17449)
Browse files Browse the repository at this point in the history
Add relax scatter_nd op support and ONNX frontend support.
  • Loading branch information
Hzfengsy authored Oct 10, 2024
1 parent d50ec23 commit 910ee0e
Show file tree
Hide file tree
Showing 11 changed files with 387 additions and 3 deletions.
12 changes: 12 additions & 0 deletions include/tvm/relax/attrs/manipulate.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,18 @@ struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
"either \"update\", \"add\", \"mul\", \"mean\", \"min\" or \"max\".");
}
}; // struct ScatterElementsAttrs

/*! \brief Attributes used in scatter_nd operators */
struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
String reduction;

TVM_DECLARE_ATTRS(ScatterNDAttrs, "relax.attrs.ScatterNDAttrs") {
TVM_ATTR_FIELD(reduction).set_default("update").describe(
"Accumulation mode of the ScatterND, "
"either \"update\", \"add\", \"mul\", \"min\" or \"max\".");
}
}; // struct ScatterNDAttrs

} // namespace relax
} // namespace tvm

Expand Down
32 changes: 31 additions & 1 deletion python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,36 @@ def _impl_v11(cls, bb, inputs, attr, params):
return relax.op.scatter_elements(inputs[0], inputs[1], inputs[2], axis=axis)


class ScatterND(OnnxOpConverter):
"""Convert an onnx ScatterND node into an equivalent Relax expression."""

@staticmethod
def _reduction_check(attr, valid_reductions: List[str]):
reduction = attr.get("reduction", None)
reduction = reduction or b"update"
reduction = reduction.decode("utf-8")
reduction = "update" if reduction == "none" else reduction
assert (
reduction in valid_reductions
), f"Only {valid_reductions} reductions are supported, but {reduction} is gotten"

return reduction

@classmethod
def _impl_v11(cls, bb, inputs, attr, params):
return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2])

@classmethod
def _impl_v16(cls, bb, inputs, attr, params):
reduction = cls._reduction_check(attr, ["update", "add", "mul"])
return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction)

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
reduction = cls._reduction_check(attr, ["update", "add", "mul", "min", "max"])
return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction)


class Size(OnnxOpConverter):
"""Convert an onnx Size node into an equivalent Relax expression."""

Expand Down Expand Up @@ -2827,7 +2857,7 @@ def _get_convert_map():
# "GatherND": GatherND,
"Scatter": Scatter,
"ScatterElements": ScatterElements,
# "ScatterND": ScatterND,
"ScatterND": ScatterND,
# "Compress": Compress,
"Size": Size,
# "EyeLike": EyeLike,
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
repeat,
reshape,
scatter_elements,
scatter_nd,
split,
squeeze,
tile,
Expand Down
39 changes: 39 additions & 0 deletions python/tvm/relax/op/manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,42 @@ def scatter_elements(
"""
return _ffi_api.scatter_elements(data, indices, updates, axis, reduction) # type: ignore


def scatter_nd(data: Expr, indices: Expr, updates: Expr, reduction: str = "update") -> Expr:
"""Scatter updates into an array according to indices.
Parameters
----------
data: relax.Expr
The input data to be updated.
indices: relax.Expr
The index positions to update in `data`.
updates: relax.Expr
Values to replace to.
reduction: str
Type of reduction to apply: update, add, mul, max, min.
It is "update" by default.
Returns
-------
result : relax.Expr
The result has the same shape as data.
Examples
--------
.. code-block:: python
# inputs
data = [1, 2, 3, 4, 5, 6, 7, 8]
indices = [[4], [3], [1], [7]]
updates = [9, 10, 11, 12]
# output
output = [1, 11, 3, 10, 9, 6, 7, 12]
"""
return _ffi_api.scatter_nd(data, indices, updates, reduction) # type: ignore
17 changes: 17 additions & 0 deletions python/tvm/relax/transform/legalize_ops/manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,23 @@ def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
)


@register_legalize("relax.scatter_nd")
def _scatter_nd(bb: BlockBuilder, call: Call) -> Expr:
# TODO(relax-team): Support native scatter_nd without te extern
def scatter_nd(data, indices, updates, reduction):
axes = list(range(len(indices.shape)))
indices = topi.transpose(indices, axes[-1:] + axes[:-1])
return topi.scatter_nd(data, indices, updates, reduction)

return bb.call_te(
scatter_nd,
call.args[0],
call.args[1],
call.args[2],
call.attrs.reduction,
)


@register_legalize("relax.layout_transform")
def _layout_transform(bb: BlockBuilder, call: Call) -> Expr:
def te_layout_transform(data, name):
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@
round,
rsqrt,
scatter_elements,
scatter_nd,
shape_of,
shape_to_tensor,
sigmoid,
Expand Down Expand Up @@ -738,6 +739,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"cumsum",
"einsum",
"scatter_elements",
"scatter_nd",
"dataflow",
"device",
"divide",
Expand Down
134 changes: 134 additions & 0 deletions src/relax/op/tensor/manipulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1531,5 +1531,139 @@ TVM_REGISTER_OP("relax.scatter_elements")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterElements)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.scatter_nd */
TVM_REGISTER_NODE_TYPE(ScatterNDAttrs);

Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction) {
auto attrs = make_object<ScatterNDAttrs>();
attrs->reduction = std::move(reduction);
static const Op& op = Op::Get("relax.scatter_nd");
return Call(op, {data, indices, updates}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.scatter_nd").set_body_typed(scatter_nd);

StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) {
// `call->args` contains: [data, indices, updates]
arith::Analyzer* analyzer = ctx->GetAnalyzer();
ICHECK_EQ(call->args.size(), 3);
const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
const auto* indices_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
const auto* updates_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);

if (data_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "ScatterND op requires the input data to be a tensor. However, the given type is "
<< call->args[0]->GetTypeKey());
}
if (indices_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "ScatterND op requires the input indices to be a tensor. However, the given type is "
<< call->args[1]->GetTypeKey());
}
if (updates_sinfo == nullptr) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "ScatterND op requires the input updates to be a tensor. However, the given type is "
<< call->args[2]->GetTypeKey());
}

if (data_sinfo->IsUnknownDtype() || updates_sinfo->IsUnknownDtype()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "ScatterND op requires the input data and updates to have known dtype. "
"However, the given types are "
<< "data: " << data_sinfo->dtype << ", updates: " << updates_sinfo->dtype);
}

if (data_sinfo->dtype != updates_sinfo->dtype) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "ScatterND op requires the input data to have same type with updates. "
"However, the given types are "
<< "data: " << data_sinfo->dtype << ", updates: " << updates_sinfo->dtype);
}

if (indices_sinfo->IsUnknownDtype()) {
LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type.";
} else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "ScatterND op requires the input indices to have integer dtype. However, "
"the given indices dtype is "
<< indices_sinfo->dtype);
}

const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
const auto* updates_shape = updates_sinfo->shape.as<ShapeExprNode>();

if (data_shape && indices_shape && updates_shape) {
const IntImmNode* k_dim = indices_shape->values[indices_sinfo->ndim - 1].as<IntImmNode>();
if (!k_dim) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "ScatterND needs a static shape for the last axis of indices, got "
<< indices_shape->values);
}
const size_t data_ndim = data_sinfo->ndim;
const size_t indices_ndim = indices_sinfo->ndim;
const size_t updates_ndim = updates_sinfo->ndim;
if (data_ndim + indices_ndim - k_dim->value - 1 != updates_ndim) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "ScatterND op requires the updates tensor to have the rank of "
"`data tensor + indices tensor - last axis of indices tensor - 1`. "
"However, the given shapes are "
<< "data: " << ShapeExpr(data_shape->values)
<< ", indices: " << ShapeExpr(indices_shape->values)
<< ", updates: " << ShapeExpr(updates_shape->values));
}
if (k_dim->value > static_cast<int>(data_ndim)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "ScatterND op requires the last axis of indices tensor to be less than "
"or equal to the rank of data tensor. However, the given shapes are "
<< "data: " << ShapeExpr(data_shape->values)
<< ", indices: " << ShapeExpr(indices_shape->values));
}
Array<PrimExpr> expected_updates_shape;
for (size_t i = 0; i < indices_ndim - 1; i++) {
expected_updates_shape.push_back(indices_shape->values[i]);
}
for (size_t i = k_dim->value; i < data_ndim; i++) {
expected_updates_shape.push_back(data_shape->values[i]);
}
auto check_shape = [&](const Array<PrimExpr>& expected, const Array<PrimExpr>& actual) {
if (expected.size() != actual.size()) {
return false;
}
for (size_t i = 0; i < expected.size(); i++) {
if (!analyzer->CanProve(expected[i] == actual[i])) {
return false;
}
}
return true;
};
if (!check_shape(expected_updates_shape, updates_shape->values)) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "ScatterND op requires the updates tensor to have the shape with constraint: "
<< "`updates.shape = indices.shape[:-1] + data.shape[K:]`, but got "
<< "updates.shape: " << ShapeExpr(updates_shape->values) << ", indices.shape: "
<< ShapeExpr(indices_shape->values) << ", data.shape: " << ShapeExpr(data_shape->values));
}
}
if (data_shape) {
return TensorStructInfo(ShapeExpr(data_shape->values), data_sinfo->dtype, data_sinfo->vdevice);
}
return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
}

TVM_REGISTER_OP("relax.scatter_nd")
.set_attrs_type<ScatterNDAttrs>()
.set_num_inputs(3)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("indices", "Tensor", "The indices tensor.")
.add_argument("updates", "Tensor", "The input tensor of updates.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterND)
.set_attr<Bool>("FPurity", Bool(true));

} // namespace relax
} // namespace tvm
33 changes: 33 additions & 0 deletions src/relax/op/tensor/manipulate.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,39 @@ Expr tile(Expr data, Array<Integer> repeats);
*/
Expr flip(Expr data, Integer axis);

/*!
* \brief Scatter updates into an array according to indices.
* \param data The input tensor.
* \param indices The index positions to update in `data`.
* \param updates The values to replace to.
* \param axis The axis along which to scatter the elements.
* \param reduction The reduction mode of the scatter elements,
* either "update", "add", "mul", "mean", "max" or "min".
* \return The computed result.
*/
Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String reduction);

/*!
* \brief Scatter updates into an array according to indices.
* \param data The input tensor to be updated.
* \param indices The index positions to update in `data`.
* \param updates The values to replace to.
* \param reduction The reduction mode of the scatter operation.
* Supported modes are:
* - "update": Replace the values at the indices with the update values.
* - "add": Add the update values to the existing values at the indices.
* - "mul": Multiply the existing values at the indices by the update values.
* - "max": Take the maximum of the existing value and the update value at each index.
* - "min": Take the minimum of the existing value and the update value at each index.
* \return The computed result tensor with the same shape as `data`.
*
* \note The shape of `indices` defines the shape of the scattered tensor.
* The last dimension of `indices` corresponds to the depth of each index vector.
* The shape of `updates` must match the shape of `indices` except for the last dimension,
* which must match the slice shape at each index.
*/
Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction);

} // namespace relax
} // namespace tvm

Expand Down
33 changes: 32 additions & 1 deletion tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def check_correctness(
tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
# Legalize any relax ops into tensorir.
tvm_model = relax.transform.LegalizeOps()(tvm_model)
print(tvm_model)

# Separate model from parameters.
tvm_model, params = relax.frontend.detach_params(tvm_model)
Expand Down Expand Up @@ -523,6 +522,38 @@ def test_scatter(axis: int, name: str, opset: int):
check_correctness(model, inputs={"indices": indices}, opset=opset)


@pytest.mark.parametrize("reduction", ["none", "add", "mul"])
def test_scatter_nd(reduction):
def verify_scatter_nd(data_shape, indices_shape, updates_shape):
scatter_nd_node = helper.make_node(
"ScatterND",
["data", "indices", "updates"],
["output"],
reduction=reduction,
)

graph = helper.make_graph(
[scatter_nd_node],
"scatter_nd_test",
inputs=[
helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape),
helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape),
helper.make_tensor_value_info("updates", TensorProto.FLOAT, updates_shape),
],
outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, data_shape)],
)

model = helper.make_model(graph, producer_name="scatter_nd_test")

indices = np.random.choice(data_shape[0], indices_shape)
check_correctness(model, inputs={"indices": indices}, opset=16)

verify_scatter_nd([8], [4, 1], [4])
verify_scatter_nd([4, 4, 4], [2, 1], [2, 4, 4])
verify_scatter_nd([4, 5, 6], [2, 3, 2], [2, 3, 6])
verify_scatter_nd([10], [5, 1], [5])


def test_size():
test_node = helper.make_node("Size", ["x"], ["y"])
graph = helper.make_graph(
Expand Down
Loading

0 comments on commit 910ee0e

Please sign in to comment.