Skip to content

Commit

Permalink
[Torch] Fix advanced indexing with NoneType index
Browse files Browse the repository at this point in the history
  • Loading branch information
padreofthegame committed Jan 23, 2023
1 parent b77d24c commit 68f98e9
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 5 deletions.
40 changes: 35 additions & 5 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2330,19 +2330,49 @@ def one_hot(self, inputs, input_types):

def index(self, inputs, input_types):
data = inputs[0]
data_shape = self.infer_type(data).shape

axes_adv_idx = [i for i, v in enumerate(inputs[1]) if v is not None]
axes_rest = [i for i in range(len(data_shape)) if i not in axes_adv_idx]

# check if the adv_index axes are consecutive
# if consecutive, result must be transposed again at the end
consecutive = True
for curr, nxt in zip(axes_adv_idx[:-1], axes_adv_idx[1:]):
if nxt - curr != 1:
consecutive = False
break

indices_list = []
axes_order = axes_adv_idx + axes_rest

for indices in inputs[1]:
if self.infer_type(indices).dtype == "bool":
for i in axes_adv_idx:
inp = inputs[1][i]
if self.infer_type(inp).dtype == "bool":
# adv_index does not support a mask as the index tensor (it will treat 0/1 as
# an index rather than a flag).
# So we use argwhere to turn the mask into indices, which will also take care
# of the dynamism in the indexing by mask.
indices_list.append(_op.squeeze(_op.transform.argwhere(indices), axis=[1]))
indices_list.append(_op.squeeze(_op.transform.argwhere(inp), axis=[1]))
else:
indices_list.append(indices)
indices_list.append(inp)

data_after_adv_index = _op.adv_index([_op.transpose(data, axes=axes_order)] + indices_list)

return _op.adv_index([data] + indices_list)
if consecutive:
num_dims = len(self.infer_type(data_after_adv_index).shape)
num_new_dims = num_dims - len(axes_rest)

axes_final_order = list(range(num_dims))
axes_final_order = (
axes_final_order[num_new_dims : num_new_dims + axes_adv_idx[0]]
+ axes_final_order[:num_new_dims]
+ axes_final_order[num_new_dims + axes_adv_idx[0] :]
)

return _op.transpose(data_after_adv_index, axes=axes_final_order)
else:
return data_after_adv_index

def meshgrid(self, inputs, input_types):
data = inputs[0]
Expand Down
35 changes: 35 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4034,6 +4034,41 @@ def forward(self, x):
input_data = torch.rand(input_shape).float()
verify_model(Index1().eval(), input_data=input_data)

class Index2(Module):
def forward(self, x):
return x[None, [2, 2]]

input_data = torch.rand(input_shape).float()
verify_model(Index2().eval(), input_data=input_data)

class Index3(Module):
def forward(self, x):
return x[None, [0, 1, 2], 1, [2, 3, 4]]

input_data = torch.rand(input_shape).float()
verify_model(Index3().eval(), input_data=input_data)

class Index4(Module):
def forward(self, x):
return x[None, [0, 0], None, np.array([[0], [1], [2]]), None]

input_data = torch.rand(input_shape).float()
verify_model(Index4().eval(), input_data=input_data)

class Index5(Module):
def forward(self, x):
return x[None, None, [0, 0], np.array([[0], [1], [2]]), None]

input_data = torch.rand(input_shape).float()
verify_model(Index5().eval(), input_data=input_data)

class Index6(Module):
def forward(self, x):
return x[None, 1, None, [1, 2, 3]]

input_data = torch.rand(input_shape).float()
verify_model(Index6().eval(), input_data=input_data)

def test_fn_bool_mask():
return lambda data, mask: data[0, mask]

Expand Down

0 comments on commit 68f98e9

Please sign in to comment.