Skip to content

Commit

Permalink
[PT FE] Fix issue with adding Result to mutated tensor (openvinotoolk…
Browse files Browse the repository at this point in the history
…it#20690)

* [PT FE] Fix issue with adding Result to mutated tensor

* Add test
  • Loading branch information
mvafin authored Oct 26, 2023
1 parent 52d3588 commit 5b8433f
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/frontends/pytorch/src/translate_session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,10 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
auto mutated_tensor = tensor_map->at(tensor_id);
// empty external_tensor_map means this is main body of the model and we don't want to create
// additional outputs in that case.
if (mutated_tensor.get_target_inputs().empty() && !external_tensor_map.empty())
if (!external_tensor_map.empty()) {
OPENVINO_DEBUG << "Creating Result for mutated tensor " << tensor_id;
results.push_back(std::make_shared<v0::Result>(tensor_map->at(tensor_id)));
}
} else {
OPENVINO_DEBUG << "Mutated tensor with id " << tensor_id << " doesn't exist in inputs, skipping.";
}
Expand Down
45 changes: 45 additions & 0 deletions tests/layer_tests/pytorch_tests/test_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import os

import pytest
import numpy as np

from pytorch_layer_test_class import PytorchLayerTest


class TestLoopWithAlias(PytorchLayerTest):
def _prepare_input(self):
return (np.random.randn(*self.shape).astype(np.float32),)

def create_model(self, n):
import torch

class loop_alias_model(torch.nn.Module):
def __init__(self, n):
super(loop_alias_model, self).__init__()
self.n = n

def forward(self, x):
N = x.shape[1]
res = torch.zeros(1, self.n, dtype=torch.long)
d = torch.ones(1, N) * 1e10
f = torch.zeros(1, dtype=torch.long)

for i in range(self.n):
res[:, i] = f
_d = torch.sum((x - x[0, f, :]) ** 2, -1)
m = _d < d
d[m] = _d[m]
f = torch.max(d, -1)[1]
return res

return loop_alias_model(n), None, ["prim::Loop", "aten::copy_"]

@pytest.mark.parametrize("s,n", [([1, 1024, 3], 512), ([1, 512, 3], 128)])
@pytest.mark.nightly
@pytest.mark.precommit
def test_loop_alias(self, s, n, ie_device, precision, ir_version):
self.shape = s
self._test(*self.create_model(n), ie_device, precision,
ir_version, use_convert_model=True)

0 comments on commit 5b8433f

Please sign in to comment.