Skip to content

Commit

Permalink
Add eltwise types resolving. Support big int constants. (#15415)
Browse files Browse the repository at this point in the history
* Add eltwise types resolving. Support big int constants.

* Update src/bindings/python/src/openvino/frontend/pytorch/decoder.py

* Small fix

* Fix some cases

* Add tests for add in different types

* Add tests for mul

* Add tests for sub and div

* Small fixes

* Return list handling (needed for empty lists)

* Add test for empty list

* Update src/frontends/pytorch/src/op/mul.cpp

Co-authored-by: Roman Kazantsev <[email protected]>

* Use refs instead of ptrs

* Apply suggestions from code review

Co-authored-by: Roman Kazantsev <[email protected]>

* Apply code review suggestions

* Fix code style

* Add more eltwise ops

---------

Co-authored-by: Roman Kazantsev <[email protected]>
  • Loading branch information
mvafin and rkazants authored Feb 2, 2023
1 parent 8051c2d commit 9264910
Show file tree
Hide file tree
Showing 15 changed files with 1,009 additions and 132 deletions.
78 changes: 37 additions & 41 deletions src/bindings/python/src/openvino/frontend/pytorch/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ def get_type_from_py_type(value):
if isinstance(value, float):
return OVType.f32
if isinstance(value, int):
return OVType.i32
# Python int is 64 bit, but we will convert it to int32 except cases when it can't fit in 32 bits
if torch.iinfo(torch.int).min <= value <= torch.iinfo(torch.int).max:
return OVType.i32
return OVType.i64
if isinstance(value, bool):
return OVType.boolean
return OVType.dynamic
Expand All @@ -27,13 +30,13 @@ def ivalue_to_constant(ivalue):
if ov_type.is_static():
return op.Constant(ov_type, Shape([]), [ivalue]).outputs()

if isinstance(ivalue, list):
if isinstance(ivalue, (list, tuple)):
assert len(ivalue) > 0, "Can't deduce type for empty list"
ov_type = get_type_from_py_type(ivalue[0])
assert ov_type.is_static(), "Can't deduce type for list"
return op.Constant(ov_type, Shape([len(ivalue)]), ivalue).outputs()

if ivalue.type() in pt_to_ov_type_map:
if isinstance(ivalue, torch.Tensor) and ivalue.type() in pt_to_ov_type_map:
try:
ovshape = PartialShape(ivalue.size())
ovtype = pt_to_ov_type_map[ivalue.type()]
Expand All @@ -46,6 +49,7 @@ def ivalue_to_constant(ivalue):
ovshape = PartialShape(nvalues.shape)
ov_const = op.Constant(ovtype, ovshape.get_shape(), nvalues.flatten().tolist())
return ov_const.outputs()
return None


def get_value_from_getattr(getattr_node, self_module):
Expand All @@ -69,25 +73,22 @@ def get_value_from_getattr(getattr_node, self_module):
pt_to_ov_type_map = {
"float": OVType.f32,
"int": OVType.i32,
"bool": OVType.boolean,
"torch.float16": OVType.f16,
"torch.float32": OVType.f32,
"torch.float64": OVType.f64,
"torch.uint8": OVType.u8,
"torch.int8": OVType.i8,
"torch.int32": OVType.i32,
"torch.bool": OVType.boolean,
"torch.int64": OVType.i64,
"torch.bool": OVType.boolean,
"torch.DoubleTensor": OVType.f64,
"torch.FloatTensor": OVType.f32,
"torch.IntTensor": OVType.i32,
"torch.LongTensor": OVType.i64,
"torch.BoolTensor": OVType.boolean,
}

pt_to_py_type_map = {
"float": "float",
"int": "int",
"torch.float32": "float",
"torch.int32": "int",
"torch.int64": "int",
"torch.bool": "bool",
}

np_to_ov_type_map = {
"float32": OVType.f32,
"int32": OVType.i32,
Expand All @@ -106,7 +107,7 @@ def __init__(self, pt_module, graph_element=None):
self.graph_element = graph_element
self.pt_module = pt_module

def inputs(self):
def inputs(self) -> list:
return [x.unique() for x in self.graph_element.inputs()]

def get_input(self, index: int):
Expand Down Expand Up @@ -150,7 +151,7 @@ def _get_known_type_for_value(self, pt_type):
# Not yet recognized
return OVAny(OVType.dynamic)

def get_shape_for_value(self, value):
def get_shape_for_value(self, value: torch.Value):
if value.isCompleteTensor():
ps = PartialShape(value.type().sizes())
return ps
Expand All @@ -161,7 +162,7 @@ def get_shape_for_value(self, value):
pass
return PartialShape.dynamic()

def get_type_for_value(self, value):
def get_type_for_value(self, value: torch.Value):
full_type = self._get_known_type_for_value(value.type())
return full_type

Expand All @@ -184,46 +185,46 @@ def get_output_transpose_order(self, index: int) -> list:
def get_subgraph_size(self) -> int:
return len(self.get_subgraphs()) if hasattr(self.graph_element, "blocks") else 1

def visit_subgraph(self, node_visitor):
def visit_subgraph(self, node_visitor) -> None:
# make sure topological order is satisfied
for node in self.graph_element.nodes():
decoder = TorchScriptPythonDecoder(self.pt_module, node)
self.m_decoders.append(decoder)
node_visitor(decoder)

def get_subgraphs(self):
def get_subgraphs(self) -> list:
return list(self.graph_element.blocks())

def get_subgraph_decoder(self, index):
def get_subgraph_decoder(self, index: int):
decoder = TorchScriptPythonDecoder(self.pt_module, self.get_subgraphs()[index])
self.m_decoders.append(decoder)
return decoder

def get_op_type(self):
def get_op_type(self) -> str:
return self.graph_element.kind()

def get_schema(self):
def get_schema(self) -> str:
return self.graph_element.schema()

def outputs(self):
def outputs(self) -> list:
return [x.unique() for x in self.graph_element.outputs()]

def _raw_outputs(self):
def _raw_outputs(self) -> list:
return list(self.graph_element.outputs())

def _raw_output(self, index):
def _raw_output(self, index: int):
return self._raw_outputs()[index]

def _raw_inputs(self):
def _raw_inputs(self) -> list:
return list(self.graph_element.inputs())

def _raw_input(self, index):
def _raw_input(self, index: int):
return self._raw_inputs()[index]

def num_of_outputs(self):
return len(self.outputs())

def output(self, index):
def output(self, index: int):
return self.outputs()[index]

def mark_node(self, node):
Expand All @@ -232,7 +233,7 @@ def mark_node(self, node):
def try_decode_get_attr(self):
pt_value = get_value_from_getattr(self.graph_element, self.pt_module)
assert pt_value is not None, "Couldn't retrieve value from prim::GetAttr"
if not isinstance(pt_value, torch.jit.ScriptModule) or isinstance(pt_value, torch.jit.TracedModule):
if not isinstance(pt_value, (torch.jit.ScriptModule, torch.jit.TracedModule)):
return ivalue_to_constant(pt_value)
else:
return []
Expand All @@ -244,17 +245,10 @@ def as_constant(self):

pt_type = pt_value.type()
if isinstance(pt_type, torch.TensorType):
return self.as_constant_tensor(pt_value)
return self._as_constant_tensor(pt_value)
if isinstance(pt_type, torch.ListType):
return self.as_constant_list(pt_value)
if str(pt_type) in ["torch.int32", "int"]:
return op.Constant(OVType.i32, Shape([]), [pt_value.toIValue()]).outputs()
if str(pt_type) in ["torch.float", "torch.FloatType", "float"]:
return op.Constant(OVType.f32, Shape([]), [pt_value.toIValue()]).outputs()
if str(pt_type) in ["torch.bool", "bool"]:
return op.Constant(OVType.boolean, Shape([]), [pt_value.toIValue()]).outputs()

return None
return self._as_constant_list(pt_value)
return ivalue_to_constant(pt_value.toIValue())

def as_string(self):
if not self.get_op_type() == "prim::Constant":
Expand All @@ -265,7 +259,8 @@ def as_string(self):
return pt_value.toIValue()
return None

def as_constant_tensor(self, pt_value):
@staticmethod
def _as_constant_tensor(pt_value: torch.Value):
ivalue = pt_value.toIValue()
if pt_value.isCompleteTensor():
try:
Expand Down Expand Up @@ -295,7 +290,8 @@ def as_constant_tensor(self, pt_value):
return ivalue_to_constant(ivalue)
return None

def as_constant_list(self, pt_value):
@staticmethod
def _as_constant_list(pt_value: torch.Value):
# For now it is treat a list as a 1D tensor; it is required by converters to avoid need to massively
# rewrite them in that part where constant attributes are queried
pt_element_type = str(pt_value.type().getElementType())
Expand All @@ -308,7 +304,7 @@ def as_constant_list(self, pt_value):
ov_const = op.Constant(ovtype, ovshape.get_shape(), ivalue)
return ov_const.outputs()

def input_is_none(self, index):
def input_is_none(self, index: int) -> bool:
if index >= len(self.inputs()) or self._raw_input(index) is None:
return True
else:
Expand Down
13 changes: 9 additions & 4 deletions src/frontends/pytorch/src/op/add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/add.hpp"

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/multiply.hpp"
#include "utils.hpp"

namespace ov {
Expand All @@ -12,12 +15,14 @@ namespace pytorch {
namespace op {

OutputVector translate_add(NodeContext& context) {
auto lhs = context.get_input(0);
auto rhs = context.get_input(1);
align_eltwise_input_types(context, lhs, rhs);
if (!context.input_is_none(2)) {
auto converted_alpha = std::make_shared<opset10::ConvertLike>(context.get_input(2), rhs);
rhs = std::make_shared<opset10::Multiply>(converted_alpha, rhs);
auto converted_alpha = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(context.get_input(2), rhs));
rhs = context.mark_node(std::make_shared<ov::op::v1::Multiply>(converted_alpha, rhs));
}
return {context.mark_node(std::make_shared<opset10::Add>(context.get_input(0), rhs))};
return {context.mark_node(std::make_shared<ov::op::v1::Add>(lhs, rhs))};
};

} // namespace op
Expand Down
37 changes: 24 additions & 13 deletions src/frontends/pytorch/src/op/div.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/floor.hpp"
#include "utils.hpp"

using namespace ov::op;

namespace ov {
namespace frontend {
namespace pytorch {
Expand All @@ -14,21 +19,27 @@ namespace op {
OutputVector translate_div(NodeContext& context) {
auto x = context.get_input(0);
auto y = context.get_input(1);
auto res = context.mark_node(std::make_shared<opset10::Divide>(x, y, true));
std::string rounding_mode = "";
if (!context.input_is_none(2)) {
auto rounding_mode = context.const_input<std::string>(2);
if (rounding_mode == "floor") {
res = context.mark_node(std::make_shared<opset10::Floor>(res));
} else if (rounding_mode == "trunc") {
const auto convert = context.mark_node(std::make_shared<opset10::Convert>(res, element::i64));
res = context.mark_node(std::make_shared<opset10::ConvertLike>(convert, x));
} else {
FRONT_END_OP_CONVERSION_CHECK(false,
"Openvino Pytorch Frontend doesn't support rounding mode ",
rounding_mode,
" for aten::div");
rounding_mode = context.const_input<std::string>(2);
}
if (rounding_mode.empty()) {
// if no rounding mode and both inputs are ints cast BOTH to fp32
const auto x_dtype = x.get_element_type();
const auto y_dtype = y.get_element_type();
if (x_dtype.is_static() && x_dtype.is_integral() && y_dtype.is_static() && y_dtype.is_integral()) {
x = context.mark_node(std::make_shared<v0::Convert>(x, element::f32));
y = context.mark_node(std::make_shared<v0::Convert>(y, element::f32));
}
}
align_eltwise_input_types(context, x, y, true);
auto res = context.mark_node(std::make_shared<v1::Divide>(x, y, true));
if (rounding_mode == "floor") {
res = context.mark_node(std::make_shared<v0::Floor>(res));
} else if (rounding_mode == "trunc") {
const auto convert = context.mark_node(std::make_shared<v0::Convert>(res, element::i64));
res = context.mark_node(std::make_shared<v1::ConvertLike>(convert, x));
}
return {res};
};

Expand Down
25 changes: 25 additions & 0 deletions src/frontends/pytorch/src/op/pow.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/power.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_pow(NodeContext& context) {
num_inputs_check(context, 1, 2);
auto lhs = context.get_input(0);
auto rhs = context.get_input(1);
align_eltwise_input_types(context, lhs, rhs, true);
return {context.mark_node(std::make_shared<ov::op::v1::Power>(lhs, rhs))};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
13 changes: 9 additions & 4 deletions src/frontends/pytorch/src/op/sub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/subtract.hpp"
#include "utils.hpp"

using namespace ov::op;

namespace ov {
namespace frontend {
namespace pytorch {
Expand All @@ -14,13 +18,14 @@ namespace op {
OutputVector translate_sub(NodeContext& context) {
auto x = context.get_input(0);
auto y = context.get_input(1);
align_eltwise_input_types(context, x, y);
// default alpha is 1 so no need to multiply if alpha is not provided
if (!context.input_is_none(2)) {
auto alpha = context.get_input(2);
auto casted_alpha = context.mark_node(std::make_shared<opset10::ConvertLike>(alpha, y));
y = context.mark_node(std::make_shared<opset10::Multiply>(casted_alpha, y));
auto casted_alpha = context.mark_node(std::make_shared<v1::ConvertLike>(alpha, y));
y = context.mark_node(std::make_shared<v1::Multiply>(casted_alpha, y));
}
return {context.mark_node(std::make_shared<opset10::Subtract>(x, y))};
return {context.mark_node(std::make_shared<v1::Subtract>(x, y))};
};

} // namespace op
Expand Down
Loading

0 comments on commit 9264910

Please sign in to comment.