
[PT FE] Add prim::PythonOp #15714

Merged · 13 commits · Mar 7, 2023
14 changes: 13 additions & 1 deletion src/bindings/python/src/openvino/frontend/pytorch/decoder.py
@@ -10,6 +10,8 @@

import warnings
import torch
from torch import _C as torch_C
from torch.onnx import symbolic_helper


def get_type_from_py_type(value):
@@ -191,7 +193,10 @@ def get_output_transpose_order(self, index: int) -> list:
return []

def get_subgraph_size(self) -> int:
return len(self.get_subgraphs()) if hasattr(self.graph_element, "blocks") else 1
if isinstance(self.graph_element, torch_C.Node):
return len(self.get_subgraphs())
else:
return 1

def visit_subgraph(self, node_visitor) -> None:
# make sure topological order is satisfied
@@ -201,6 +206,13 @@ def visit_subgraph(self, node_visitor) -> None:
node_visitor(decoder)

def get_subgraphs(self) -> list:
if self.graph_element.kind() == "prim::PythonOp":
if "Subgraph" in self.graph_element.attributeNames():
return [symbolic_helper._node_get(self.graph_element, "Subgraph")]
else:
# Attribute "Subgraph" is only available if Graph was created using tracing.
# TODO Find way to extract subgraph for scripted Graph.
return []
return list(self.graph_element.blocks())

def get_subgraph_decoder(self, index: int):
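For readers unfamiliar with how this attribute surfaces, here is a minimal sketch (not part of the PR; the Exp and Model names are illustrative) of how tracing a custom torch.autograd.Function yields a prim::PythonOp node whose traced body can be read back via symbolic_helper._node_get, exactly as get_subgraphs does above:

import torch
from torch.onnx import symbolic_helper

class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.exp(x)

class Model(torch.nn.Module):
    def forward(self, x):
        return Exp.apply(x)

# Tracing records the autograd.Function call as a prim::PythonOp node.
traced = torch.jit.trace(Model(), torch.randn(1, 3, 2, 2))
for node in traced.graph.nodes():
    if node.kind() == "prim::PythonOp" and "Subgraph" in node.attributeNames():
        # Traced graphs attach the function body as a "Subgraph" attribute;
        # scripted graphs do not (hence the TODO above).
        print(symbolic_helper._node_get(node, "Subgraph"))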
51 changes: 51 additions & 0 deletions src/frontends/pytorch/src/op/pythonop.cpp
@@ -0,0 +1,51 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "pt_framework_node.hpp"
#include "translate_session.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_pythonop(NodeContext& context) {
auto decoder = context.get_decoder();
FRONT_END_OP_CONVERSION_CHECK(decoder->get_subgraph_size() == 1,
"PythonOp must have 1 subgraph to be able to translate it to OV.");
auto body = context.convert_subgraph(0);
auto session = context.get_session();

std::map<size_t, ParameterVector> inputs_map;
for (const auto& param : body->get_parameters()) {
auto tensor_idx = session->decode_tensor_name(param->output(0));
if (!inputs_map.count(tensor_idx)) {
inputs_map[tensor_idx] = {param};
} else {
inputs_map[tensor_idx].push_back(param);
}
}
for (const auto& input : inputs_map) {
auto external_output = context.get_tensor_from_model(input.first);
if (external_output.get_node()) {
for (auto input_node : input.second) {
replace_node(input_node, context.get_input((int)input.first).get_node_shared_ptr());
}
}
}

OutputVector outputs{};
for (auto result : body->get_results()) {
auto output = result->get_input_source_output(0).get_node_shared_ptr();
context.mark_node(output);
outputs.push_back(output);
}
return outputs;
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
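As a rough end-to-end sketch of what this translator enables (not part of the PR): converting a traced model containing a prim::PythonOp should now inline the op's subgraph body into the resulting OpenVINO model. The convert_model entry point below is an assumption based on the Model Optimizer Python API of that era, and Model is the illustrative class from the decoder sketch above:

import torch
from openvino.tools.mo import convert_model  # assumed 2023.0-era entry point

example = torch.randn(1, 3, 128, 128)
traced = torch.jit.trace(Model(), example)  # graph contains prim::PythonOp
# translate_pythonop converts the op's single traced subgraph and rewires
# the body's Parameters and Results into the surrounding graph.
ov_model = convert_model(traced, example_input=example)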
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
@@ -81,6 +81,7 @@ OP_CONVERTER(translate_ones);
OP_CONVERTER(translate_ones_like);
OP_CONVERTER(translate_pad);
OP_CONVERTER(translate_pow);
OP_CONVERTER(translate_pythonop);
OP_CONVERTER(translate_reciprocal);
OP_CONVERTER(translate_relu6);
OP_CONVERTER(translate_remainder);
@@ -316,6 +317,7 @@ const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
{"prim::Loop", op::translate_loop},
{"prim::NumToTensor", op::skip_node}, // In openvino we already store number as tensor with shape []
{"prim::requires_grad", op::return_false_scalar},
{"prim::PythonOp", op::translate_pythonop},
{"torchvision::nms", op::translate_nms},
};
};
56 changes: 56 additions & 0 deletions tests/layer_tests/pytorch_tests/test_pythonop.py
@@ -0,0 +1,56 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest

from pytorch_layer_test_class import PytorchLayerTest


class TestPythonOp(PytorchLayerTest):
def _prepare_input(self):
return (np.random.randn(1, 3, 128, 128),)

def create_model(self):
import torch
from torch.autograd.function import Function

class _ExpF(Function):
@staticmethod
def forward(ctx, input_tensor):
exp = torch.exp(input_tensor)
ctx.save_for_backward(exp)
return exp

exp_f = _ExpF.apply

class prim_pythonop(torch.nn.Module):
def forward(self, input_tensor):
return exp_f(input_tensor)

ref_net = None

return prim_pythonop(), ref_net, "prim::PythonOp"

@pytest.mark.parametrize(
("use_trace"),
[
True,
pytest.param(
False,
marks=pytest.mark.xfail(
reason="Subgraph of prim::PythonOp cannot be retrived using if using scripting."
),
),
],
)
@pytest.mark.nightly
@pytest.mark.precommit
def test_pythonop(self, use_trace, ie_device, precision, ir_version):
self._test(
*self.create_model(),
ie_device,
precision,
ir_version,
trace_model=use_trace
)