Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

udpate scatter spec #7086

Merged
merged 13 commits into from
Sep 13, 2021
77 changes: 52 additions & 25 deletions docs/ops/movement/ScatterUpdate_3.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,82 +7,80 @@
**Short description**: *ScatterUpdate* creates a copy of the first input tensor with updated elements specified with second and third input tensors.

**Detailed description**: *ScatterUpdate* creates a copy of the first input tensor with updated elements in positions specified with `indices` input
and values specified with `updates` tensor starting from the dimension with index `axis`. For the `data` tensor of shape `[d_0, d_1, ..., d_n]`,
`indices` tensor of shape `[i_0, i_1, ..., i_k]` and `updates` tensor of shape
`[d_0, d_1, ... d_(axis - 1), i_0, i_1, ..., i_k, d_(axis + 1), ..., d_n]` the operation computes
and values specified with `updates` tensor starting from the dimension with index `axis`. For the `data` tensor of shape \f$[d_0,\;d_1,\;\dots,\;d_n]\f$,
`indices` tensor of shape \f$[i_0,\;i_1,\;\dots,\;i_k]\f$ and `updates` tensor of shape
\f$[d_0,\;d_1,\;\dots,\;d_{axis - 1},\;i_0,\;i_1,\;\dots,\;i_k,\;d_{axis + 1},\;\dots, d_n]\f$ the operation computes
for each `m, n, ..., p` of the `indices` tensor indices:

```
data[..., indices[m, n, ..., p], ...] = updates[..., m, n, ..., p, ...]
```

where first `...` in the `data` corresponds to first `axis` dimensions, last `...` in the `data` corresponds to the
\f[data[\dots,\;indices[m,\;n,\;\dots,\;p],\;\dots] = updates[\dots,\;m,\;n,\;\dots,\;p,\;\dots]\f]

where first \f$\dots\f$ in the `data` corresponds to \f$[d_0,\;\dots,\;d_{axis - 1}]\f$ dimensions, last\f$\dots\f$ in the `data` corresponds to the
`rank(data) - (axis + 1)` dimensions.

Several examples for case when `axis = 0`:
1. `indices` is a 0D tensor: `data[indices, ...] = updates[...]`
2. `indices` is a 1D tensor (for each `i`): `data[indices[i], ...] = updates[i, ...]`
3. `indices` is a ND tensor (for each `i, ..., j`): `data[indices[i, ..., j], ...] = updates[i, ..., j, ...]`

This operation is similar to TensorFlow* operation [ScatterUpdate](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/scatter_update)
but allows scattering for the arbitrary axis.
1. `indices` is a \f$0\f$D tensor: \f$data[indices,\;\dots] = updates[\dots]\f$
2. `indices` is a \f$1\f$D tensor (\f$\forall_{i}\f$): \f$data[indices[i],\;\dots] = updates[i,\;\dots]\f$
3. `indices` is a \f$N\f$D tensor (\f$\forall_{i,\;\dots,\;j}\f$): \f$data[indices[i],\;\dots,\;j],\;\dots] = updates[i,\;\dots,\;j,\;\dots]\f$

**Attributes**: *ScatterUpdate* does not have attributes.

**Inputs**:

* **1**: `data` tensor of arbitrary rank `r` and of type *T*. **Required.**
* **1**: `data` tensor of arbitrary rank `r` and of type *T_NUMERIC*. **Required.**
bszmelcz marked this conversation as resolved.
Show resolved Hide resolved

* **2**: `indices` tensor with indices of type *T_IND*.
All index values are expected to be within bounds `[0, s - 1]` along axis of size `s`. If multiple indices point to the
bszmelcz marked this conversation as resolved.
Show resolved Hide resolved
same output location then the order of updating the values is undefined. If an index points to non-existing output
bszmelcz marked this conversation as resolved.
Show resolved Hide resolved
tensor element or is negative then an exception is raised. **Required.**
bszmelcz marked this conversation as resolved.
Show resolved Hide resolved

* **3**: `updates` tensor of type *T*. **Required.**
* **3**: `updates` tensor of type *T_NUMERIC* and rank equal to `rank(indices) + rank(data) - 1` **Required.**

* **4**: `axis` tensor with scalar or 1D tensor with one element of type *T_AXIS* specifying axis for scatter.
The value can be in range `[-r, r - 1]` where `r` is the rank of `data`. **Required.**
The value can be in range `[ -r, r - 1]` where `r` is the rank of `data`. **Required.**
bszmelcz marked this conversation as resolved.
Show resolved Hide resolved

**Outputs**:

* **1**: tensor with shape equal to `data` tensor of the type *T*.
* **1**: tensor with shape equal to `data` tensor of the type *T_NUMERIC*.

**Types**

* *T*: any numeric type.
* *T_NUMERIC*: any numeric type.

* *T_IND*: any supported integer types.

* *T_AXIS*: any supported integer types.

**Example**
**Examples**

*Example 1*

```xml
<layer ... type="ScatterUpdate">
<input>
<port id="0">
<port id="0"> <!-- data -->
<dim>1000</dim>
<dim>256</dim>
<dim>10</dim>
<dim>15</dim>
</port>
<port id="1">
<port id="1"> <!-- indices -->
<dim>125</dim>
<dim>20</dim>
</port>
<port id="2">
<port id="2"> <!-- udpates -->
<dim>1000</dim>
<dim>125</dim>
<dim>20</dim>
<dim>10</dim>
<dim>15</dim>
</port>
<port id="3"> <!-- value [1] -->
<dim>1</dim>
<port id="3"> <!-- axis -->
<dim>1</dim> <!-- value [1] -->
</port>
</input>
<output>
<port id="4" precision="FP32">
<port id="4" precision="FP32"> <!-- output -->
<dim>1000</dim>
<dim>256</dim>
<dim>10</dim>
Expand All @@ -91,3 +89,32 @@ The value can be in range `[-r, r - 1]` where `r` is the rank of `data`. **Requi
</output>
</layer>
```

*Example 2*

```xml
<layer ... type="ScatterUpdate">
<input>
<port id="0"> <!-- data -->
<dim>3</dim> <!-- {{-1.0f, 1.0f, -1.0f, 3.0f, 4.0f}, -->
<dim>5</dim> <!-- {-1.0f, 6.0f, -1.0f, 8.0f, 9.0f}, -->
</port> <!-- {-1.0f, 11.0f, 1.0f, 13.0f, 14.0f}} -->
<port id="1"> <!-- indices -->
<dim>2</dim> <!-- {0, 2} -->
</port>
<port id="2"> <!-- udpates -->
<dim>3</dim> <!-- {1.0f, 1.0f} -->
<dim>2</dim> <!-- {1.0f, 1.0f} -->
</port> <!-- {1.0f, 2.0f} -->
<port id="3"> <!-- axis -->
<dim>1</dim> <!-- {1} -->
</port>
</input>
<output>
<port id="4"> <!-- output -->
<dim>3</dim> <!-- {{1.0f, 1.0f, 1.0f, 3.0f, 4.0f}, -->
<dim>5</dim> <!-- {1.0f, 6.0f, 1.0f, 8.0f, 9.0f}, -->
</port> <!-- {1.0f, 11.0f, 2.0f, 13.0f, 14.0f}} -->
</output>
</layer>
```