[PT FE] Add prim::PythonOp (openvinotoolkit#15714)
* Add PythonOp
* Fix deprecation & cleanup
* Apply suggestions from code review
* Fix dtype
* Apply suggestions from code review
  Co-authored-by: Maxim Vafin <[email protected]>
* Update to new tensor names handling
* Fix negation
* Apply changes from code review
* Remove unnecessary imports
* Update src/frontends/pytorch/src/op/pythonop.cpp
  Co-authored-by: Maxim Vafin <[email protected]>
---------
Co-authored-by: Maxim Vafin <[email protected]>
Showing 4 changed files with 116 additions and 1 deletion.
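For background: when a model is traced, a call to a custom torch.autograd.Function (via its .apply method) is recorded in the TorchScript graph as a prim::PythonOp node whose single subgraph holds the traced body of the function. This is the node the commit teaches the PyTorch frontend to translate. A minimal sketch of how such a node appears; the module and names here are illustrative, not taken from the commit:

import torch
from torch.autograd.function import Function

class _Exp(Function):
    @staticmethod
    def forward(ctx, x):
        result = torch.exp(x)
        ctx.save_for_backward(result)
        return result

class ExpModule(torch.nn.Module):
    def forward(self, x):
        # Calling .apply is what the tracer records as prim::PythonOp.
        return _Exp.apply(x)

traced = torch.jit.trace(ExpModule(), torch.randn(1, 3, 8, 8))
print(traced.graph)  # expected to list a prim::PythonOp node carrying one subgraph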
@@ -0,0 +1,46 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "pt_framework_node.hpp"
#include "translate_session.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_pythonop(NodeContext& context) {
    auto decoder = context.get_decoder();
    FRONT_END_OP_CONVERSION_CHECK(decoder->get_subgraph_size() == 1,
                                  "PythonOp must have 1 subgraph to be able to translate it to OV.");
    // Convert the single subgraph that holds the body of the Python function.
    auto body = context.convert_subgraph(0);
    auto session = context.get_session();

    // Map each body Parameter to the tensor index encoded in its tensor name.
    std::map<size_t, ParameterVector> inputs_map;
    for (const auto& param : body->get_parameters()) {
        auto tensor_idx = session->decode_tensor_name(param->output(0));
        FRONT_END_OP_CONVERSION_CHECK(!inputs_map.count(tensor_idx),
                                      "Multiple nodes with the same id are not allowed.");
        inputs_map[tensor_idx] = {param};
    }
    // Rewire each body Parameter to the corresponding external input of PythonOp.
    for (const auto& input : inputs_map) {
        auto external_output = context.get_input((int)input.first);
        if (external_output.get_node()) {
            input.second[0]->output(0).replace(external_output);
        }
    }

    // Expose the body Results as the outputs of the translated node.
    OutputVector outputs{};
    for (auto result : body->get_results()) {
        auto output = result->get_input_source_output(0);
        outputs.push_back(context.mark_output(output));
    }
    return outputs;
};

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov
@@ -0,0 +1,56 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest

from pytorch_layer_test_class import PytorchLayerTest


class TestPythonOp(PytorchLayerTest):
    def _prepare_input(self):
        return (np.random.randn(1, 3, 128, 128),)

    def create_model(self):
        import torch
        from torch.autograd.function import Function

        # Custom autograd Function: calling its .apply inside forward shows up
        # as a prim::PythonOp node in the traced graph.
        class _ExpF(Function):
            @staticmethod
            def forward(ctx, input_tensor):
                exp = torch.exp(input_tensor)
                ctx.save_for_backward(exp)
                return exp

        exp_f = _ExpF.apply

        class prim_pythonop(torch.nn.Module):
            def forward(self, input_tensor):
                return exp_f(input_tensor)

        ref_net = None

        return prim_pythonop(), ref_net, "prim::PythonOp"

    @pytest.mark.parametrize(
        ("use_trace"),
        [
            True,
            pytest.param(
                False,
                marks=pytest.mark.xfail(
                    reason="Subgraph of prim::PythonOp cannot be retrieved when using scripting."
                ),
            ),
        ],
    )
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_pythonop(self, use_trace, ie_device, precision, ir_version):
        self._test(
            *self.create_model(),
            ie_device,
            precision,
            ir_version,
            trace_model=use_trace
        )
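With the translator in place, a traced model that contains prim::PythonOp should convert end to end: the node's single subgraph is converted, its Parameters are rewired to the node's external inputs, and its Results become the node's outputs. A hedged sketch of such a conversion, assuming an OpenVINO release that includes this change and exposes openvino.convert_model for torch.nn.Module inputs; the module and tensor shapes are illustrative:

import torch
from torch.autograd.function import Function

import openvino as ov

class _Exp(Function):
    @staticmethod
    def forward(ctx, x):
        return torch.exp(x)

class ExpModule(torch.nn.Module):
    def forward(self, x):
        # .apply is what gets recorded as prim::PythonOp when tracing
        return _Exp.apply(x)

example = torch.randn(1, 3, 128, 128)
# With example_input given, convert_model is expected to trace the module, so the
# prim::PythonOp node reaches the translator added in this commit.
ov_model = ov.convert_model(ExpModule(), example_input=example)
compiled = ov.compile_model(ov_model, "CPU")
result = compiled([example.numpy()])[compiled.output(0)]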