[PT FE] Add prim::PythonOp #15714

Merged · 13 commits · Mar 7, 2023
13 changes: 12 additions & 1 deletion src/bindings/python/src/openvino/frontend/pytorch/decoder.py
@@ -10,6 +10,7 @@

import warnings
import torch
from torch.onnx import symbolic_helper


def get_type_from_py_type(value):
@@ -185,7 +186,10 @@ def get_output_transpose_order(self, index: int) -> list:
return []

def get_subgraph_size(self) -> int:
return len(self.get_subgraphs()) if hasattr(self.graph_element, "blocks") else 1
try:
return len(self.get_subgraphs())
except AttributeError:
return 1

def visit_subgraph(self, node_visitor) -> None:
# make sure topological order is satisfied
@@ -195,6 +199,13 @@ def visit_subgraph(self, node_visitor) -> None:
node_visitor(decoder)

def get_subgraphs(self) -> list:
if self.graph_element.kind() == "prim::PythonOp":
if "Subgraph" in self.graph_element.attributeNames():
return [symbolic_helper._node_get(self.graph_element, "Subgraph")]
else:
# Attribute "Subgraph" is only available if Graph was created using tracing.
# TODO Find way to extract subgraph for scripted Graph.
return []
return list(self.graph_element.blocks())

def get_subgraph_decoder(self, index: int):
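
For reference, a minimal sketch of the tracing behavior the decoder change relies on: a custom autograd.Function recorded by torch.jit.trace surfaces as a prim::PythonOp node whose traced body is stored in the "Subgraph" attribute. The module and input shape below are illustrative assumptions; symbolic_helper._node_get is the same private torch.onnx helper imported in decoder.py above.

import torch
from torch.autograd.function import Function
from torch.onnx import symbolic_helper

class _ExpF(Function):
    @staticmethod
    def forward(ctx, x):
        return torch.exp(x)

    @staticmethod
    def backward(ctx, grad):
        # Never called here: tracing only executes forward.
        raise NotImplementedError

class Model(torch.nn.Module):
    def forward(self, x):
        return _ExpF.apply(x) + 1

traced = torch.jit.trace(Model(), torch.randn(1, 3, 8, 8))
for node in traced.graph.nodes():
    if node.kind() == "prim::PythonOp" and "Subgraph" in node.attributeNames():
        # The attribute holds the body traced through the Function's forward.
        print(symbolic_helper._node_get(node, "Subgraph"))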
51 changes: 51 additions & 0 deletions src/frontends/pytorch/src/op/pythonop.cpp
@@ -0,0 +1,51 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/util/op_types.hpp"
#include "pt_framework_node.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_pythonop(NodeContext& context) {
auto decoder = context.get_decoder();
FRONT_END_OP_CONVERSION_CHECK(decoder->get_subgraph_size() == 1,
"PythonOp must have 1 subgraph to be able to translate it to OV.");
auto body = context.convert_subgraph(0);

std::map<size_t, ParameterVector> inputs_map;
for (const auto& param : body->get_parameters()) {
auto name = param->get_output_tensor(0).get_any_name();
size_t input_idx = static_cast<size_t>(std::stoll(name));
// Group body parameters by the index of the outer input they mirror.
if (inputs_map.count(input_idx)) {
inputs_map[input_idx].push_back(param);
} else {
inputs_map[input_idx] = {param};
}
}
for (const auto& input : inputs_map) {
auto external_output = context.get_tensor_from_model(input.first);
if (external_output.get_node()) {
for (const auto& input_node : input.second) {
replace_node(input_node, context.get_input(input.first).get_node_shared_ptr());
}
}
}

OutputVector outputs{};
for (const auto& result : body->get_results()) {
auto output = result->get_input_source_output(0).get_node_shared_ptr();
context.mark_node(output);
outputs.push_back(output);
}
return outputs;
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
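
The converter inlines the PythonOp body rather than keeping an opaque framework node: each body Parameter carries the index of the outer input as its tensor name (hence the std::stoll above), is replaced by the corresponding outer producer, and the inputs of the body Results become the op's outputs, so downstream transformations see plain OV operations. As a usage sketch, a traced model containing such a custom function can then be converted end to end; the convert_model entry point and its example_input parameter are an assumption based on the OpenVINO tooling available around this PR, and Model is the illustrative module from the earlier sketch.

import torch
from openvino.tools.mo import convert_model  # assumed entry point, circa 2023.x

model = Model()  # illustrative module calling a custom autograd.Function
# example_input makes the frontend trace the model, so the resulting
# prim::PythonOp carries the "Subgraph" attribute this converter needs.
ov_model = convert_model(model, example_input=torch.randn(1, 3, 8, 8))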
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
@@ -80,6 +80,7 @@ OP_CONVERTER(translate_ones);
OP_CONVERTER(translate_ones_like);
OP_CONVERTER(translate_pad);
OP_CONVERTER(translate_pow);
OP_CONVERTER(translate_pythonop);
OP_CONVERTER(translate_reciprocal);
OP_CONVERTER(translate_relu6);
OP_CONVERTER(translate_remainder);
@@ -314,6 +315,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"prim::Loop", op::translate_loop},
{"prim::NumToTensor", op::skip_node}, // In openvino we already store number as tensor with shape []
{"prim::requires_grad", op::return_false_scalar},
{"prim::PythonOp", op::translate_pythonop},
{"torchvision::nms", op::translate_nms},
};
};
61 changes: 61 additions & 0 deletions tests/layer_tests/pytorch_tests/test_pythonop.py
@@ -0,0 +1,61 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest

from pytorch_layer_test_class import PytorchLayerTest


class TestPythonOp(PytorchLayerTest):
def _prepare_input(self):
return (np.random.randn(1, 3, 128, 128),)

def create_model(self):
import torch
from torch.autograd.function import Function

class _ExpF(Function):
@staticmethod
def forward(ctx, input_tensor):
exp = torch.exp(input_tensor)
ctx.save_for_backward(exp)
return exp

@staticmethod
def backward(ctx, output_grad):
(result,) = ctx.saved_tensors
return output_grad * result

exp_f = _ExpF.apply

class prim_pythonop(torch.nn.Module):
def forward(self, input_tensor):
return torch.add(exp_f(input_tensor), 1)
Contributor: Could you also add a test for DeformableConvolution, which we have in models?

Contributor Author: This op would add torchvision::deform_conv2d to the unconverted ops, and adding a test would require enabling that op as well. I think it is better to enable it in a separate PR to avoid mixing changes, but if you want, I can enable it in the current PR.


ref_net = None

return prim_pythonop(), ref_net, "prim::PythonOp"

@pytest.mark.parametrize(
("use_trace"),
[
True,
pytest.param(
False,
marks=pytest.mark.xfail(
reason="Subgraph of prim::PythonOp cannot be retrived using if using scripting."
),
),
],
)
@pytest.mark.nightly
@pytest.mark.precommit
def test_pythonop(self, use_trace, ie_device, precision, ir_version):
self._test(
*self.create_model(),
ie_device,
precision,
ir_version,
trace_model=use_trace
)
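
For completeness, a sketch of the limitation behind the xfail: under torch.jit.script a prim::PythonOp (when one is produced at all) carries no "Subgraph" attribute, so get_subgraphs() in the decoder returns an empty list and the converter's single-subgraph check fails. Model is again the illustrative module from the decoder sketch.

import torch

scripted = torch.jit.script(Model())  # Model from the earlier sketch
for node in scripted.graph.nodes():
    if node.kind() == "prim::PythonOp":
        # Expected to print False: scripting does not attach a traced body.
        print("Subgraph" in node.attributeNames())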