Skip to content

Commit

Permalink
[PT FE] Add prim::PythonOp (openvinotoolkit#15714)
Browse files Browse the repository at this point in the history
* Add PythonOp

* Fix deprecation & cleanup

* Apply suggestions from code review

* Fix dtype

* Apply suggestions from code review

Co-authored-by: Maxim Vafin <[email protected]>

* Update to new tensor names handling

* Fix negation

* Apply changes from code review

* Remove unnecesary imports

* Update src/frontends/pytorch/src/op/pythonop.cpp

Co-authored-by: Maxim Vafin <[email protected]>

---------

Co-authored-by: Maxim Vafin <[email protected]>
  • Loading branch information
2 people authored and andrei-cv committed Mar 21, 2023
1 parent 3ef62d3 commit 6ad9789
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/bindings/python/src/openvino/frontend/pytorch/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,10 @@ def get_output_transpose_order(self, index: int) -> list:
return []

def get_subgraph_size(self) -> int:
return len(self.get_subgraphs()) if hasattr(self.graph_element, "blocks") else 1
if isinstance(self.graph_element, torch.Node):
return len(self.get_subgraphs())
else:
return 1

def visit_subgraph(self, node_visitor) -> None:
# make sure topological order is satisfied
Expand All @@ -197,6 +200,14 @@ def visit_subgraph(self, node_visitor) -> None:
node_visitor(decoder)

def get_subgraphs(self) -> list:
if self.graph_element.kind() == "prim::PythonOp":
if "Subgraph" in self.graph_element.attributeNames():
assert isinstance(self.graph_element, torch.Node), "Graph element must be of type torch.Node."
return [getattr(self.graph_element, self.graph_element.kindOf("Subgraph"))("Subgraph")]
else:
# Attribute "Subgraph" is only available if Graph was created using tracing.
# TODO Find way to extract subgraph for scripted Graph.
return []
return list(self.graph_element.blocks())

def get_subgraph_decoder(self, index: int):
Expand Down
46 changes: 46 additions & 0 deletions src/frontends/pytorch/src/op/pythonop.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "pt_framework_node.hpp"
#include "translate_session.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_pythonop(NodeContext& context) {
auto decoder = context.get_decoder();
FRONT_END_OP_CONVERSION_CHECK(decoder->get_subgraph_size() == 1,
"PythonOp must have 1 subgraph to be able to translate it to OV.");
auto body = context.convert_subgraph(0);
auto session = context.get_session();

std::map<size_t, ParameterVector> inputs_map;
for (const auto& param : body->get_parameters()) {
auto tensor_idx = session->decode_tensor_name(param->output(0));
FRONT_END_OP_CONVERSION_CHECK(!inputs_map.count(tensor_idx),
"Multiple nodes with the same id are not allowed.");
inputs_map[tensor_idx] = {param};
}
for (const auto& input : inputs_map) {
auto external_output = context.get_input((int)input.first);
if (external_output.get_node()) {
input.second[0]->output(0).replace(external_output);
}
}

OutputVector outputs{};
for (auto result : body->get_results()) {
auto output = result->get_input_source_output(0);
outputs.push_back(context.mark_output(output));
}
return outputs;
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ OP_CONVERTER(translate_ones);
OP_CONVERTER(translate_ones_like);
OP_CONVERTER(translate_pad);
OP_CONVERTER(translate_pow);
OP_CONVERTER(translate_pythonop);
OP_CONVERTER(translate_reciprocal);
OP_CONVERTER(translate_relu6);
OP_CONVERTER(translate_remainder);
Expand Down Expand Up @@ -333,6 +334,7 @@ const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
{"prim::Loop", op::translate_loop},
{"prim::NumToTensor", op::skip_node}, // In openvino we already store number as tensor with shape []
{"prim::requires_grad", op::return_false_scalar},
{"prim::PythonOp", op::translate_pythonop},
{"prim::type", op::skip_node}, // Used with prim::device, pass PtFrameworkNode.
{"torchvision::nms", op::translate_nms},
{"torchvision::roi_align", op::translate_roi_align},
Expand Down
56 changes: 56 additions & 0 deletions tests/layer_tests/pytorch_tests/test_pythonop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest

from pytorch_layer_test_class import PytorchLayerTest


class TestPythonOp(PytorchLayerTest):
def _prepare_input(self):
return (np.random.randn(1, 3, 128, 128),)

def create_model(self):
import torch
from torch.autograd.function import Function

class _ExpF(Function):
@staticmethod
def forward(ctx, input_tensor):
exp = torch.exp(input_tensor)
ctx.save_for_backward(exp)
return exp

exp_f = _ExpF.apply

class prim_pythonop(torch.nn.Module):
def forward(self, input_tensor):
return exp_f(input_tensor)

ref_net = None

return prim_pythonop(), ref_net, "prim::PythonOp"

@pytest.mark.parametrize(
("use_trace"),
[
True,
pytest.param(
False,
marks=pytest.mark.xfail(
reason="Subgraph of prim::PythonOp cannot be retrived using if using scripting."
),
),
],
)
@pytest.mark.nightly
@pytest.mark.precommit
def test_pythonop(self, use_trace, ie_device, precision, ir_version):
self._test(
*self.create_model(),
ie_device,
precision,
ir_version,
trace_model=use_trace
)

0 comments on commit 6ad9789

Please sign in to comment.