Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Gather-7 specification #5441

Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions docs/ops/movement/Gather_7.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,19 @@ TensorFlow\* [Gather](https://www.tensorflow.org/api_docs/python/tf/gather) oper

**Detailed description**

output[p_0, p_1, ..., p_{axis-1}, p_axis, ..., p_{axis + k}, ...] =
data[p_0, p_1, ..., p_{axis-1}, indices[p_0, p_1, ..., p_{b-1}, p_b, ..., p_{axis}, j], ...]
output[p_0, p_1, ..., p_{axis-1}, i_b, ..., i_{M-1}, p_{axis+1}, ..., p_{N-1}] =
data[p_0, p_1, ..., p_{axis-1}, indices[p_0, p_1, ..., p_{b-1}, i_b, ..., i_{M-1}], p_{axis+1}, ..., p_{N-1}]

Where `data`, `indices` and `axis` are tensors from first, second and third inputs correspondingly, and `b` is
the number of batch dimensions.
Where `data`, `indices` and `axis` are tensors from first, second and third inputs correspondingly, `b` is
the number of batch dimensions. `N` and `M` are numbers of dimensions of `data` and `indices` tensors, respectively.

**Attributes**:
* *batch_dims*
* **Description**: *batch_dims* (also denoted as `b`) is a leading number of dimensions of `data` tensor and `indices`
representing the batches, and *Gather* starts to gather from the `b` dimension. It requires the first `b`
dimensions in `data` and `indices` tensors to be equal.
* **Range of values**: `[0; min(data.rank, indices.rank))` and `batch_dims <= axis`
dimensions in `data` and `indices` tensors to be equal. If *batch_dims* is less than zero used normalized value
`batch_dims = indices.rank + batch_dims`.
* **Range of values**: `[-min(data.rank, indices.rank); min(data.rank, indices.rank))` and `batch_dims <= axis`
pavel-esir marked this conversation as resolved.
Show resolved Hide resolved
* **Type**: *T_AXIS*
* **Default value**: 0
* **Required**: *no*
Expand Down Expand Up @@ -112,6 +113,24 @@ output = [[[[ 5, 6, 7, 8],
output_shape = (2, 1, 3, 4)
```

Example 5 with negative *batch_dims* value:
```
batch_dims = -1 <-- normalized value will be indices.rank + batch_dims = 2 - 1 = 1
axis = 1

indices = [[0, 0, 4], <-- this is applied to the first batch
[4, 0, 0]] <-- this is applied to the second batch
indices_shape = (2, 3)

data = [[1, 2, 3, 4, 5], <-- the first batch
[6, 7, 8, 9, 10]] <-- the second batch
data_shape = (2, 5)

output = [[ 1, 1, 5],
[10, 6, 6]]
output_shape = (2, 3)
```

**Inputs**

* **1**: `data` tensor of type *T* with arbitrary data. **Required**.
Expand Down