[Relax] Add scatter_nd op support
Add relax scatter_nd op support and ONNX frontend support.
Hzfengsy committed Oct 7, 2024
1 parent accd582 commit 04cda4e
Showing 10 changed files with 354 additions and 3 deletions.
12 changes: 12 additions & 0 deletions include/tvm/relax/attrs/manipulate.h
@@ -164,6 +164,18 @@ struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
"either \"update\", \"add\", \"mul\", \"mean\", \"min\" or \"max\".");
}
}; // struct ScatterElementsAttrs

/*! \brief Attributes used in scatter_nd operators */
struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
  String reduction;

  TVM_DECLARE_ATTRS(ScatterNDAttrs, "relax.attrs.ScatterNDAttrs") {
    TVM_ATTR_FIELD(reduction).set_default("update").describe(
        "Accumulation mode of the ScatterND, "
        "either \"update\", \"add\", \"mul\", \"min\" or \"max\".");
  }
};  // struct ScatterNDAttrs

} // namespace relax
} // namespace tvm
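
For intuition, the reduction modes enumerated above behave like the following NumPy reference (an illustrative sketch, not part of this change; scatter_nd_ref is a hypothetical helper):

import numpy as np

def scatter_nd_ref(data, indices, updates, reduction="update"):
    # indices has shape (..., K): each length-K row addresses a slice
    # data[i0, ..., i{K-1}] that is combined with the matching update.
    out = data.copy()
    for pos in np.ndindex(*indices.shape[:-1]):
        idx = tuple(indices[pos])
        if reduction == "update":
            out[idx] = updates[pos]
        elif reduction == "add":
            out[idx] += updates[pos]
        elif reduction == "mul":
            out[idx] *= updates[pos]
        elif reduction == "min":
            out[idx] = np.minimum(out[idx], updates[pos])
        elif reduction == "max":
            out[idx] = np.maximum(out[idx], updates[pos])
    return out

data = np.array([1, 2, 3, 4, 5, 6, 7, 8])
indices = np.array([[4], [3], [1], [7]])
updates = np.array([9, 10, 11, 12])
assert scatter_nd_ref(data, indices, updates).tolist() == [1, 11, 3, 10, 9, 6, 7, 12]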

32 changes: 31 additions & 1 deletion python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -604,6 +604,36 @@ def _impl_v11(cls, bb, inputs, attr, params):
        return relax.op.scatter_elements(inputs[0], inputs[1], inputs[2], axis=axis)


class ScatterND(OnnxOpConverter):
    """Convert an onnx ScatterND node into an equivalent Relax expression."""

    @staticmethod
    def _reduction_check(attr, valid_reductions: List[str]):
        reduction = attr.get("reduction", None)
        reduction = reduction or b"update"
        reduction = reduction.decode("utf-8")
        reduction = "update" if reduction == "none" else reduction
        assert (
            reduction in valid_reductions
        ), f"Only {valid_reductions} reductions are supported, but got {reduction}"

        return reduction

    @classmethod
    def _impl_v11(cls, bb, inputs, attr, params):
        return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2])

    @classmethod
    def _impl_v16(cls, bb, inputs, attr, params):
        reduction = cls._reduction_check(attr, ["update", "add", "mul"])
        return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction)

    @classmethod
    def _impl_v18(cls, bb, inputs, attr, params):
        reduction = cls._reduction_check(attr, ["update", "add", "mul", "min", "max"])
        return relax.op.scatter_nd(inputs[0], inputs[1], inputs[2], reduction)


class Size(OnnxOpConverter):
"""Convert an onnx Size node into an equivalent Relax expression."""

@@ -2729,7 +2759,7 @@ def _get_convert_map():
# "GatherND": GatherND,
"Scatter": Scatter,
"ScatterElements": ScatterElements,
# "ScatterND": ScatterND,
"ScatterND": ScatterND,
# "Compress": Compress,
"Size": Size,
# "EyeLike": EyeLike,
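With "ScatterND" enabled in the convert map above, a minimal ONNX graph exercises the new converter end to end; a sketch, with shapes and opset chosen purely for illustration:

from onnx import TensorProto, helper
from tvm.relax.frontend.onnx import from_onnx

node = helper.make_node("ScatterND", ["data", "indices", "updates"], ["out"], reduction="add")
graph = helper.make_graph(
    [node],
    "scatter_nd_example",
    inputs=[
        helper.make_tensor_value_info("data", TensorProto.FLOAT, [8]),
        helper.make_tensor_value_info("indices", TensorProto.INT64, [4, 1]),
        helper.make_tensor_value_info("updates", TensorProto.FLOAT, [4]),
    ],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, [8])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 16)])
mod = from_onnx(model)  # dispatches to ScatterND._impl_v16 above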
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
@@ -91,6 +91,7 @@
    repeat,
    reshape,
    scatter_elements,
    scatter_nd,
    split,
    squeeze,
    tile,
39 changes: 39 additions & 0 deletions python/tvm/relax/op/manipulate.py
@@ -511,3 +511,42 @@ def scatter_elements(
"""
return _ffi_api.scatter_elements(data, indices, updates, axis, reduction) # type: ignore


def scatter_nd(data: Expr, indices: Expr, updates: Expr, reduction: str = "update") -> Expr:
    """Scatter updates into an array according to indices.

    Parameters
    ----------
    data: relax.Expr
        The input data to be updated.

    indices: relax.Expr
        The index positions to update in `data`.

    updates: relax.Expr
        The values to write at the given index positions.

    reduction: str
        Type of reduction to apply: "update", "add", "mul", "min" or "max".
        It is "update" by default.

    Returns
    -------
    result : relax.Expr
        The result tensor, with the same shape and dtype as `data`.

    Examples
    --------
    .. code-block:: python

        # inputs
        data = [1, 2, 3, 4, 5, 6, 7, 8]
        indices = [[4], [3], [1], [7]]
        updates = [9, 10, 11, 12]
        # output
        output = [1, 11, 3, 10, 9, 6, 7, 12]

    """
    return _ffi_api.scatter_nd(data, indices, updates, reduction)  # type: ignore
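A minimal sketch of driving the new Python API through the standard Relax flow (variable names are illustrative):

from tvm import relax
from tvm.script import relax as R

bb = relax.BlockBuilder()
data = relax.Var("data", R.Tensor((8,), "float32"))
indices = relax.Var("indices", R.Tensor((4, 1), "int64"))
updates = relax.Var("updates", R.Tensor((4,), "float32"))

with bb.function("main", [data, indices, updates]):
    out = bb.emit(relax.op.scatter_nd(data, indices, updates, reduction="add"))
    bb.emit_func_output(out)

# Lower relax.scatter_nd to TIR through the legalizer registered below.
mod = relax.transform.LegalizeOps()(bb.get())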
17 changes: 17 additions & 0 deletions python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -168,6 +168,23 @@ def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
    )


@register_legalize("relax.scatter_nd")
def _scatter_nd(bb: BlockBuilder, call: Call) -> Expr:
    # TODO(relax-team): Support native scatter_nd without te extern
    def scatter_nd(data, indices, updates, reduction):
        # topi.scatter_nd expects the index depth as the leading axis of
        # `indices`, while relax follows the ONNX convention of a trailing
        # axis, so move the last axis to the front.
        axes = list(range(len(indices.shape)))
        indices = topi.transpose(indices, axes[-1:] + axes[:-1])
        return topi.scatter_nd(data, indices, updates, reduction)

    return bb.call_te(
        scatter_nd,
        call.args[0],
        call.args[1],
        call.args[2],
        call.attrs.reduction,
    )
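A quick NumPy sanity check of the transpose above: relax (following ONNX) keeps the index depth K as the trailing axis of indices, while topi.scatter_nd expects it leading.

import numpy as np

indices = np.array([[4], [3], [1], [7]])  # relax/ONNX layout: (num_updates, K) == (4, 1)
axes = list(range(indices.ndim))
assert np.transpose(indices, axes[-1:] + axes[:-1]).shape == (1, 4)  # TOPI layout: (K, num_updates)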


@register_legalize("relax.layout_transform")
def _layout_transform(bb: BlockBuilder, call: Call) -> Expr:
    def te_layout_transform(data, name):
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
@@ -136,6 +136,7 @@
    round,
    rsqrt,
    scatter_elements,
    scatter_nd,
    shape_of,
    shape_to_tensor,
    sigmoid,
@@ -736,6 +737,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"cumsum",
"einsum",
"scatter_elements",
"scatter_nd",
"dataflow",
"device",
"divide",
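With scatter_nd exported through the ir_builder above, the op is also scriptable; a hedged TVMScript sketch (module and shapes are illustrative):

from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class ScatterExample:
    @R.function
    def main(
        data: R.Tensor((8,), "float32"),
        indices: R.Tensor((4, 1), "int64"),
        updates: R.Tensor((4,), "float32"),
    ) -> R.Tensor((8,), "float32"):
        out = R.scatter_nd(data, indices, updates, reduction="update")
        return out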
134 changes: 134 additions & 0 deletions src/relax/op/tensor/manipulate.cc
@@ -1531,5 +1531,139 @@ TVM_REGISTER_OP("relax.scatter_elements")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterElements)
    .set_attr<Bool>("FPurity", Bool(true));

/* relax.scatter_nd */
TVM_REGISTER_NODE_TYPE(ScatterNDAttrs);

Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction) {
  auto attrs = make_object<ScatterNDAttrs>();
  attrs->reduction = std::move(reduction);
  static const Op& op = Op::Get("relax.scatter_nd");
  return Call(op, {data, indices, updates}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.scatter_nd").set_body_typed(scatter_nd);

StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) {
  // `call->args` contains: [data, indices, updates]
  arith::Analyzer* analyzer = ctx->GetAnalyzer();
  ICHECK_EQ(call->args.size(), 3);
  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
  const auto* indices_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
  const auto* updates_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);

  if (data_sinfo == nullptr) {
    ctx->ReportFatal(
        Diagnostic::Error(call)
        << "ScatterND op requires the input data to be a tensor. However, the given type is "
        << call->args[0]->GetTypeKey());
  }
  if (indices_sinfo == nullptr) {
    ctx->ReportFatal(
        Diagnostic::Error(call)
        << "ScatterND op requires the input indices to be a tensor. However, the given type is "
        << call->args[1]->GetTypeKey());
  }
  if (updates_sinfo == nullptr) {
    ctx->ReportFatal(
        Diagnostic::Error(call)
        << "ScatterND op requires the input updates to be a tensor. However, the given type is "
        << call->args[2]->GetTypeKey());
  }

  if (data_sinfo->IsUnknownDtype() || updates_sinfo->IsUnknownDtype()) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "ScatterND op requires the input data and updates to have a known dtype. "
                        "However, the given types are "
                     << "data: " << data_sinfo->dtype << ", updates: " << updates_sinfo->dtype);
  }

  if (data_sinfo->dtype != updates_sinfo->dtype) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "ScatterND op requires the input data and updates to have the same dtype. "
                        "However, the given types are "
                     << "data: " << data_sinfo->dtype << ", updates: " << updates_sinfo->dtype);
  }

  if (indices_sinfo->IsUnknownDtype()) {
    LOG(WARNING) << "Data type of indices has not been specified. Assume it has an integer type.";
  } else if (!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "ScatterND op requires the input indices to have integer dtype. However, "
                        "the given indices dtype is "
                     << indices_sinfo->dtype);
  }

  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
  const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
  const auto* updates_shape = updates_sinfo->shape.as<ShapeExprNode>();

  if (data_shape && indices_shape && updates_shape) {
    const IntImmNode* k_dim = indices_shape->values[indices_sinfo->ndim - 1].as<IntImmNode>();
    if (!k_dim) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "ScatterND needs a static shape for the last axis of indices, got "
                       << indices_shape->values);
    }
    const size_t data_ndim = data_sinfo->ndim;
    const size_t indices_ndim = indices_sinfo->ndim;
    const size_t updates_ndim = updates_sinfo->ndim;
    if (data_ndim + indices_ndim - k_dim->value - 1 != updates_ndim) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "ScatterND op requires the updates tensor to have rank "
                          "`data.ndim + indices.ndim - indices.shape[-1] - 1`. "
                          "However, the given shapes are "
                       << "data: " << ShapeExpr(data_shape->values)
                       << ", indices: " << ShapeExpr(indices_shape->values)
                       << ", updates: " << ShapeExpr(updates_shape->values));
    }
    if (k_dim->value > static_cast<int>(data_ndim)) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "ScatterND op requires the last axis of the indices tensor to be less "
                          "than or equal to the rank of the data tensor. However, the given "
                          "shapes are "
                       << "data: " << ShapeExpr(data_shape->values)
                       << ", indices: " << ShapeExpr(indices_shape->values));
    }
    Array<PrimExpr> expected_updates_shape;
    for (size_t i = 0; i < indices_ndim - 1; i++) {
      expected_updates_shape.push_back(indices_shape->values[i]);
    }
    for (size_t i = k_dim->value; i < data_ndim; i++) {
      expected_updates_shape.push_back(data_shape->values[i]);
    }
    auto check_shape = [&](const Array<PrimExpr>& expected, const Array<PrimExpr>& actual) {
      if (expected.size() != actual.size()) {
        return false;
      }
      for (size_t i = 0; i < expected.size(); i++) {
        if (!analyzer->CanProve(expected[i] == actual[i])) {
          return false;
        }
      }
      return true;
    };
    if (!check_shape(expected_updates_shape, updates_shape->values)) {
      ctx->ReportFatal(
          Diagnostic::Error(call)
          << "ScatterND op requires the updates tensor to satisfy the constraint "
          << "`updates.shape = indices.shape[:-1] + data.shape[K:]`, but got "
          << "updates.shape: " << ShapeExpr(updates_shape->values) << ", indices.shape: "
          << ShapeExpr(indices_shape->values) << ", data.shape: " << ShapeExpr(data_shape->values));
    }
  }
  if (data_shape) {
    return TensorStructInfo(ShapeExpr(data_shape->values), data_sinfo->dtype, data_sinfo->vdevice);
  }
  return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
}

TVM_REGISTER_OP("relax.scatter_nd")
    .set_attrs_type<ScatterNDAttrs>()
    .set_num_inputs(3)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("indices", "Tensor", "The indices tensor.")
    .add_argument("updates", "Tensor", "The input tensor of updates.")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterND)
    .set_attr<Bool>("FPurity", Bool(true));

} // namespace relax
} // namespace tvm
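
The updates-shape constraint enforced by InferStructInfoScatterND can be restated compactly in Python (illustrative only; expected_updates_shape is a hypothetical helper):

def expected_updates_shape(data_shape, indices_shape):
    # With K = indices_shape[-1]:
    #   updates.shape == indices.shape[:-1] + data.shape[K:]
    k = indices_shape[-1]
    return tuple(indices_shape[:-1]) + tuple(data_shape[k:])

assert expected_updates_shape((8,), (4, 1)) == (4,)
assert expected_updates_shape((4, 5, 6), (2, 3, 2)) == (2, 3, 6)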
33 changes: 32 additions & 1 deletion tests/python/relax/test_frontend_onnx.py
@@ -118,7 +118,6 @@ def check_correctness(
    tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
    # Legalize any relax ops into tensorir.
    tvm_model = relax.transform.LegalizeOps()(tvm_model)
    print(tvm_model)

    # Separate model from parameters.
    tvm_model, params = relax.frontend.detach_params(tvm_model)
@@ -487,6 +486,38 @@ def test_scatter(axis: int, name: str, opset: int):
    check_correctness(model, inputs={"indices": indices}, opset=opset)


@pytest.mark.parametrize("reduction", ["none", "add", "mul"])
def test_scatter_nd(reduction):
    def verify_scatter_nd(data_shape, indices_shape, updates_shape):
        scatter_nd_node = helper.make_node(
            "ScatterND",
            ["data", "indices", "updates"],
            ["output"],
            reduction=reduction,
        )

        graph = helper.make_graph(
            [scatter_nd_node],
            "scatter_nd_test",
            inputs=[
                helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape),
                helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape),
                helper.make_tensor_value_info("updates", TensorProto.FLOAT, updates_shape),
            ],
            outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, data_shape)],
        )

        model = helper.make_model(graph, producer_name="scatter_nd_test")

        indices = np.random.choice(data_shape[0], indices_shape)
        check_correctness(model, inputs={"indices": indices}, opset=16)

    verify_scatter_nd([8], [4, 1], [4])
    verify_scatter_nd([4, 4, 4], [2, 1], [2, 4, 4])
    verify_scatter_nd([4, 5, 6], [2, 3, 2], [2, 3, 6])
    verify_scatter_nd([10], [5, 1], [5])


def test_size():
test_node = helper.make_node("Size", ["x"], ["y"])
graph = helper.make_graph(
25 changes: 25 additions & 0 deletions tests/python/relax/test_op_manipulate.py
@@ -45,6 +45,7 @@ def test_op_correctness():
    assert relax.op.einsum(x, subscripts="ii").op == Op.get("relax.einsum")
    assert relax.op.flip(x, axis=1).op == Op.get("relax.flip")
    assert relax.op.scatter_elements(x, x, x).op == Op.get("relax.scatter_elements")
    assert relax.op.scatter_nd(x, x, x).op == Op.get("relax.scatter_nd")


def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
@@ -3352,5 +3353,29 @@ def test_scatter_elements_infer_struct_info_rank_shape_mismatch():
        bb.normalize(relax.op.scatter_elements(d0, i0, u4))


def test_scatter_nd_infer_struct_info():
    bb = relax.BlockBuilder()

    d0 = relax.Var("data", R.Tensor((8,), "float32"))
    i0 = relax.Var("indices", R.Tensor((4, 1), "int64"))
    u0 = relax.Var("updates", R.Tensor((4,), "float32"))

    _check_inference(
        bb,
        relax.op.scatter_nd(d0, i0, u0, "update"),
        relax.TensorStructInfo((8,), dtype="float32"),
    )

    d1 = relax.Var("data", R.Tensor((4, 4, 4), "float32"))
    i1 = relax.Var("indices", R.Tensor((2, 1), "int64"))
    u1 = relax.Var("updates", R.Tensor((2, 4, 4), "float32"))

    _check_inference(
        bb,
        relax.op.scatter_nd(d1, i1, u1, "update"),
        relax.TensorStructInfo((4, 4, 4), dtype="float32"),
    )


if __name__ == "__main__":
    tvm.testing.main()
