Skip to content

Commit

Permalink
Improve documentation for .take_along_dim() with enhanced examples an…
Browse files Browse the repository at this point in the history
…d other content
  • Loading branch information
muhammedazhar committed Jan 15, 2025
1 parent d8d0914 commit e096b0e
Showing 1 changed file with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ torch.take_along_dim(input, indices, dim)

## Example

### Basic Usage
Here is a basic usage example of `.take_along_dim()` in PyTorch to select elements along a specific dimension:

```python
import torch
Expand All @@ -57,14 +57,14 @@ result = torch.take_along_dim(input_tensor, indices, dim=1)
print(result)
```

### Output
The following will be the output of the above code:

```
tensor([[30, 20, 10],
[50, 40, 60]])
```

### Advanced Example: Multi-Dimensional Selection
Moreover, the function can also be used to select elements along a specific dimension in a multi-dimensional tensor. For instance, consider the following example:

```python
import torch
Expand All @@ -82,7 +82,7 @@ result = torch.take_along_dim(input_tensor, indices, dim=2)
print(result)
```

### Multi-Dimensional Selection Output
The output of the above code will be:

```
tensor([[[1, 2],
Expand All @@ -93,6 +93,8 @@ tensor([[[1, 2],

## Key Features

Here are some key features of the `.take_along_dim()` function:

1. Preserves tensor dimensionality during selection
2. Supports batch operations
3. Works with any number of dimensions
Expand Down

0 comments on commit e096b0e

Please sign in to comment.