From 84896151c0692d9105450368c4fac5d1d7c0eadd Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Fri, 21 Aug 2020 18:54:26 -0700 Subject: [PATCH] [Frontend][Pytorch]Add Pytorch advanced indexing (#6318) * Add Pytorch advanced indexing * Minor fix for test * Fix for cuda --- python/tvm/relay/frontend/pytorch.py | 53 +++++++++++++++++-- tests/python/frontend/pytorch/test_forward.py | 24 ++++++++- 2 files changed, 72 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index b75f3f909b93..723740377cde 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -274,16 +274,18 @@ def _impl(inputs, input_types): end[dim] = min(end[dim], int(inputs[3])) else: if isinstance(inputs[3], _expr.Call): - end[dim] = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int)) + target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int)) else: - end[dim] = inputs[3] + target_end = inputs[3] + + end[dim] = min(end[dim], target_end) strides.append(int(inputs[4])) return _op.transform.strided_slice(data, begin=_expr.const(begin), end=_expr.const(end), strides=_expr.const(strides), - slice_mode="size") + slice_mode="end") return _impl def _split(): @@ -1759,6 +1761,50 @@ def _impl(inputs, input_types): return _impl +def _index(): + def _impl(inputs, input_types): + data = inputs[0] + indices = [] + raw_indices = [] + max_indices_len = -1 + for index in inputs[1]: + if not isinstance(index, _expr.Constant): + try: + index = _expr.const(_infer_value(index, {})) + except Exception: + raise RuntimeError("Only supports constant indices for " + "pytorch advanced indexing ") + raw_indices.append(index) + cindex_len = index.data.shape[0] + if cindex_len > max_indices_len: + max_indices_len = cindex_len + + for index in raw_indices: + cnp = index.data.asnumpy() + cindex_len = cnp.shape[0] + if cindex_len < max_indices_len: + cnp = np.tile(cnp, max_indices_len // cindex_len) + indices.append(cnp) + + ret = [] + slice_map = {} + for i in range(indices[0].shape[0]): + tmp = data + current_indices = [] + for index in indices: + current_indices.append(index[i]) + index_key = tuple(current_indices) + if index_key in slice_map: + tmp = slice_map[index_key] + else: + tmp = _op.take(tmp, _expr.const(index[i]), axis=0) + slice_map[index_key] = tmp + ret.append(_op.expand_dims(tmp, axis=0)) + + return _op.concatenate(ret, axis=0) + return _impl + + def _meshgrid(): def _impl(inputs, input_types): data = inputs[0] @@ -2064,6 +2110,7 @@ def _get_convert_map(prelude): "aten::type_as" : _type_as(), "aten::gather" : _gather(), "aten::index_select" : _select(), + "aten::index" : _index(), } return convert_map diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index e5c963454450..ab0a4b03cafa 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1202,13 +1202,13 @@ def forward(self, *args): class Slice2(Module): def forward(self, *args): - return args[0][0, :, :, :] + return args[0][0, :, :-3, :] class Slice3(Module): def forward(self, *args): x0 = torch.tensor(2) - torch.tensor(1) x1 = torch.tensor(3) + torch.tensor(1) - return args[0][:, x0:, :x1, :] + return args[0][:, x0:, 1:x1, :] input_data = torch.rand(input_shape).float() verify_model(Slice1().float().eval(), input_data=input_data) @@ -2620,6 +2620,25 @@ def forward(self, *args): verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2]) +def test_forward_index(): + torch.set_grad_enabled(False) + input_shape = [3, 4, 5, 6] + + class Index0(Module): + def forward(self, x): + return x[[0, 1], [0, 2], :2, 4] + + input_data = torch.rand(input_shape).float() + verify_model(Index0().eval(), input_data=input_data) + + class Index1(Module): + def forward(self, x): + return x[[0], [1, 2, 3, 0], [3, 1, 2, 2], [4, 2, 1, 0]] + + input_data = torch.rand(input_shape).float() + verify_model(Index1().eval(), input_data=input_data) + + def test_forward_pretrained_bert_base_uncased(): ###################################################################### # This is an example how to run BERT models using TVM @@ -2859,6 +2878,7 @@ def test_forward_pretrained_bert_base_uncased(): test_adaptive_pool3d() test_conv3d() test_conv3d_transpose() + test_forward_index() # Model tests test_resnet18()