Skip to content

Commit

Permalink
Fix formatting inconsistencies in take_along_dim.md examples
Browse files Browse the repository at this point in the history
  • Loading branch information
muhammedazhar committed Jan 15, 2025
1 parent 9d29933 commit d8d0914
Showing 1 changed file with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ torch.take_along_dim(input, indices, dim)
import torch

# Create a source tensor
input_tensor = torch.tensor([[10, 20, 30],
input_tensor = torch.tensor([[10, 20, 30],
[40, 50, 60]])

# Define indices for selection
indices = torch.tensor([[2, 1, 0],
indices = torch.tensor([[2, 1, 0],
[1, 0, 2]])

# Select elements along dimension 1
Expand All @@ -70,11 +70,11 @@ tensor([[30, 20, 10],
import torch

# Create a 3D tensor
input_tensor = torch.tensor([[[1, 2], [3, 4]],
input_tensor = torch.tensor([[[1, 2], [3, 4]],
[[5, 6], [7, 8]]])

# Define indices for selection
indices = torch.tensor([[[0, 1], [1, 0]],
indices = torch.tensor([[[0, 1], [1, 0]],
[[0, 0], [1, 1]]])

# Select elements along the last dimension
Expand Down Expand Up @@ -118,7 +118,7 @@ tensor([[[1, 2],
import torch
# Create a tensor
tensor = torch.tensor([[10, 20, 30],
tensor = torch.tensor([[10, 20, 30],
[40, 50, 60]])
# Create indices for top-2 values along dim=1
Expand Down

0 comments on commit d8d0914

Please sign in to comment.