Add eltwise types resolving. Support big int constants. #15415

Merged · 19 commits · Feb 2, 2023
78 changes: 37 additions & 41 deletions src/bindings/python/src/openvino/frontend/pytorch/decoder.py
@@ -16,7 +16,10 @@ def get_type_from_py_type(value):
if isinstance(value, float):
return OVType.f32
if isinstance(value, int):
return OVType.i32
# Python ints default to 64-bit, but we convert to int32 unless the value can't fit in 32 bits
if torch.iinfo(torch.int).min <= value <= torch.iinfo(torch.int).max:
return OVType.i32
return OVType.i64
if isinstance(value, bool):
return OVType.boolean
return OVType.dynamic
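In effect, the decoder now narrows a Python int to i32 only when the value fits in 32 bits. A quick standalone check of the boundary (the string tags stand in for the OVType members used above):

```python
import torch

# Sketch of the new constant-typing rule: values representable in 32 bits
# are narrowed to i32, everything else becomes i64.
def pick_int_type(value: int) -> str:
    if torch.iinfo(torch.int).min <= value <= torch.iinfo(torch.int).max:
        return "i32"
    return "i64"

print(pick_int_type(42))          # i32
print(pick_int_type(2**31 - 1))   # i32 (INT_MAX still fits)
print(pick_int_type(2**31))       # i64 (needs more than 32 bits)
print(pick_int_type(-2**31 - 1))  # i64
```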
@@ -27,13 +30,13 @@ def ivalue_to_constant(ivalue):
if ov_type.is_static():
return op.Constant(ov_type, Shape([]), [ivalue]).outputs()

if isinstance(ivalue, list):
if isinstance(ivalue, (list, tuple)):
assert len(ivalue) > 0, "Can't deduce type for empty list"
ov_type = get_type_from_py_type(ivalue[0])
assert ov_type.is_static(), "Can't deduce type for list"
return op.Constant(ov_type, Shape([len(ivalue)]), ivalue).outputs()

if ivalue.type() in pt_to_ov_type_map:
if isinstance(ivalue, torch.Tensor) and ivalue.type() in pt_to_ov_type_map:
try:
ovshape = PartialShape(ivalue.size())
ovtype = pt_to_ov_type_map[ivalue.type()]
@@ -46,6 +49,7 @@ def ivalue_to_constant(ivalue):
ovshape = PartialShape(nvalues.shape)
ov_const = op.Constant(ovtype, ovshape.get_shape(), nvalues.flatten().tolist())
return ov_const.outputs()
return None


def get_value_from_getattr(getattr_node, self_module):
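With the two hunks above, ivalue_to_constant accepts tuples alongside lists, guards the tensor path with an explicit isinstance check, and falls through to None instead of raising AttributeError on a non-tensor's .type(). A rough behavioral sketch (the return values are hypothetical stand-ins for the real op.Constant outputs, and the scalar path is omitted):

```python
import torch

# Rough sketch of the updated control flow in ivalue_to_constant.
def ivalue_to_constant_sketch(ivalue):
    if isinstance(ivalue, (list, tuple)):          # tuples now accepted too
        assert len(ivalue) > 0, "Can't deduce type for empty list"
        return ("1D constant", len(ivalue))
    if isinstance(ivalue, torch.Tensor):           # guard before .type() lookup
        return ("tensor constant", tuple(ivalue.size()))
    return None                                    # explicit fallthrough

print(ivalue_to_constant_sketch((1, 2, 3)))        # ('1D constant', 3)
print(ivalue_to_constant_sketch(torch.ones(2, 2))) # ('tensor constant', (2, 2))
print(ivalue_to_constant_sketch(object()))         # None
```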
@@ -69,25 +73,22 @@ def get_value_from_getattr(getattr_node, self_module):
pt_to_ov_type_map = {
"float": OVType.f32,
"int": OVType.i32,
"bool": OVType.boolean,
"torch.float16": OVType.f16,
"torch.float32": OVType.f32,
"torch.float64": OVType.f64,
"torch.uint8": OVType.u8,
"torch.int8": OVType.i8,
"torch.int32": OVType.i32,
"torch.bool": OVType.boolean,
"torch.int64": OVType.i64,
"torch.bool": OVType.boolean,
"torch.DoubleTensor": OVType.f64,
"torch.FloatTensor": OVType.f32,
"torch.IntTensor": OVType.i32,
"torch.LongTensor": OVType.i64,
"torch.BoolTensor": OVType.boolean,
}

pt_to_py_type_map = {
"float": "float",
"int": "int",
"torch.float32": "float",
"torch.int32": "int",
"torch.int64": "int",
"torch.bool": "bool",
}

np_to_ov_type_map = {
"float32": OVType.f32,
"int32": OVType.i32,
@@ -106,7 +107,7 @@ def __init__(self, pt_module, graph_element=None):
self.graph_element = graph_element
self.pt_module = pt_module

def inputs(self):
def inputs(self) -> list:
return [x.unique() for x in self.graph_element.inputs()]

def get_input(self, index: int):
@@ -150,7 +151,7 @@ def _get_known_type_for_value(self, pt_type):
# Not yet recognized
return OVAny(OVType.dynamic)

def get_shape_for_value(self, value):
def get_shape_for_value(self, value: torch.Value):
if value.isCompleteTensor():
ps = PartialShape(value.type().sizes())
return ps
@@ -161,7 +162,7 @@ def get_shape_for_value(self, value):
pass
return PartialShape.dynamic()

def get_type_for_value(self, value):
def get_type_for_value(self, value: torch.Value):
full_type = self._get_known_type_for_value(value.type())
return full_type

@@ -184,46 +185,46 @@ def get_output_transpose_order(self, index: int) -> list:
def get_subgraph_size(self) -> int:
return len(self.get_subgraphs()) if hasattr(self.graph_element, "blocks") else 1

def visit_subgraph(self, node_visitor):
def visit_subgraph(self, node_visitor) -> None:
# make sure topological order is satisfied
for node in self.graph_element.nodes():
decoder = TorchScriptPythonDecoder(self.pt_module, node)
self.m_decoders.append(decoder)
node_visitor(decoder)

def get_subgraphs(self):
def get_subgraphs(self) -> list:
return list(self.graph_element.blocks())

def get_subgraph_decoder(self, index):
def get_subgraph_decoder(self, index: int):
decoder = TorchScriptPythonDecoder(self.pt_module, self.get_subgraphs()[index])
self.m_decoders.append(decoder)
return decoder

def get_op_type(self):
def get_op_type(self) -> str:
return self.graph_element.kind()

def get_schema(self):
def get_schema(self) -> str:
return self.graph_element.schema()

def outputs(self):
def outputs(self) -> list:
return [x.unique() for x in self.graph_element.outputs()]

def _raw_outputs(self):
def _raw_outputs(self) -> list:
return list(self.graph_element.outputs())

def _raw_output(self, index):
def _raw_output(self, index: int):
return self._raw_outputs()[index]

def _raw_inputs(self):
def _raw_inputs(self) -> list:
return list(self.graph_element.inputs())

def _raw_input(self, index):
def _raw_input(self, index: int):
return self._raw_inputs()[index]

def num_of_outputs(self):
return len(self.outputs())

def output(self, index):
def output(self, index: int):
return self.outputs()[index]

def mark_node(self, node):
Expand All @@ -232,7 +233,7 @@ def mark_node(self, node):
def try_decode_get_attr(self):
pt_value = get_value_from_getattr(self.graph_element, self.pt_module)
assert pt_value is not None, "Couldn't retrieve value from prim::GetAttr"
if not isinstance(pt_value, torch.jit.ScriptModule) or isinstance(pt_value, torch.jit.TracedModule):
if not isinstance(pt_value, (torch.jit.ScriptModule, torch.jit.TracedModule)):
return ivalue_to_constant(pt_value)
else:
return []
@@ -244,17 +245,10 @@ def as_constant(self):

pt_type = pt_value.type()
if isinstance(pt_type, torch.TensorType):
return self.as_constant_tensor(pt_value)
return self._as_constant_tensor(pt_value)
if isinstance(pt_type, torch.ListType):
return self.as_constant_list(pt_value)
if str(pt_type) in ["torch.int32", "int"]:
return op.Constant(OVType.i32, Shape([]), [pt_value.toIValue()]).outputs()
if str(pt_type) in ["torch.float", "torch.FloatType", "float"]:
return op.Constant(OVType.f32, Shape([]), [pt_value.toIValue()]).outputs()
if str(pt_type) in ["torch.bool", "bool"]:
return op.Constant(OVType.boolean, Shape([]), [pt_value.toIValue()]).outputs()

return None
return self._as_constant_list(pt_value)
return ivalue_to_constant(pt_value.toIValue())

def as_string(self):
if not self.get_op_type() == "prim::Constant":
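Together with the decoder change above, scalar prim::Constant values are now routed through ivalue_to_constant instead of hard-coded i32/f32/boolean branches, so an out-of-i32-range literal survives conversion. A tiny module of the kind this enables (hypothetical repro; the OpenVINO conversion call itself is omitted):

```python
import torch

class BigConst(torch.nn.Module):
    def forward(self, x):
        # 2**40 does not fit in int32; the constant is now emitted as i64.
        return x + 2**40

scripted = torch.jit.script(BigConst())
print(scripted(torch.tensor([1], dtype=torch.int64)))  # tensor([1099511627777])
```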
@@ -265,7 +259,8 @@ def as_string(self):
return pt_value.toIValue()
return None

def as_constant_tensor(self, pt_value):
@staticmethod
def _as_constant_tensor(pt_value: torch.Value):
ivalue = pt_value.toIValue()
if pt_value.isCompleteTensor():
try:
@@ -295,7 +290,8 @@ def as_constant_tensor(self, pt_value):
return ivalue_to_constant(ivalue)
return None

def as_constant_list(self, pt_value):
@staticmethod
def _as_constant_list(pt_value: torch.Value):
# For now we treat a list as a 1D tensor; converters rely on this to avoid massively
# rewriting the places where constant attributes are queried
pt_element_type = str(pt_value.type().getElementType())
@@ -308,7 +304,7 @@ def as_constant_list(self, pt_value):
ov_const = op.Constant(ovtype, ovshape.get_shape(), ivalue)
return ov_const.outputs()

def input_is_none(self, index):
def input_is_none(self, index: int) -> bool:
if index >= len(self.inputs()) or self._raw_input(index) is None:
return True
else:
13 changes: 9 additions & 4 deletions src/frontends/pytorch/src/op/add.cpp
@@ -2,8 +2,11 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/add.hpp"

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/multiply.hpp"
#include "utils.hpp"

namespace ov {
@@ -12,12 +15,14 @@ namespace pytorch {
namespace op {

OutputVector translate_add(NodeContext& context) {
auto lhs = context.get_input(0);
auto rhs = context.get_input(1);
align_eltwise_input_types(context, lhs, rhs);
if (!context.input_is_none(2)) {
auto converted_alpha = std::make_shared<opset10::ConvertLike>(context.get_input(2), rhs);
rhs = std::make_shared<opset10::Multiply>(converted_alpha, rhs);
auto converted_alpha = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(context.get_input(2), rhs));
rhs = context.mark_node(std::make_shared<ov::op::v1::Multiply>(converted_alpha, rhs));
}
return {context.mark_node(std::make_shared<opset10::Add>(context.get_input(0), rhs))};
return {context.mark_node(std::make_shared<ov::op::v1::Add>(lhs, rhs))};
};

} // namespace op
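For reference, aten::add's schema is add(Tensor self, Tensor other, Scalar alpha=1), i.e. lhs + alpha * rhs, which is exactly the ConvertLike/Multiply chain above; translate_sub below mirrors the same pattern for subtraction. The semantics can be checked directly in PyTorch:

```python
import torch

# aten::add computes x + alpha * y; the translator casts alpha to the
# element type of y before the multiplication.
x = torch.tensor([1.0, 2.0])
y = torch.tensor([10.0, 20.0])
alpha = 3
assert torch.equal(torch.add(x, y, alpha=alpha), x + alpha * y)
```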
37 changes: 24 additions & 13 deletions src/frontends/pytorch/src/op/div.cpp
@@ -3,9 +3,14 @@
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/floor.hpp"
#include "utils.hpp"

using namespace ov::op;

namespace ov {
namespace frontend {
namespace pytorch {
@@ -14,21 +19,27 @@ namespace op {
OutputVector translate_div(NodeContext& context) {
auto x = context.get_input(0);
auto y = context.get_input(1);
auto res = context.mark_node(std::make_shared<opset10::Divide>(x, y, true));
std::string rounding_mode = "";
if (!context.input_is_none(2)) {
auto rounding_mode = context.const_input<std::string>(2);
if (rounding_mode == "floor") {
res = context.mark_node(std::make_shared<opset10::Floor>(res));
} else if (rounding_mode == "trunc") {
const auto convert = context.mark_node(std::make_shared<opset10::Convert>(res, element::i64));
res = context.mark_node(std::make_shared<opset10::ConvertLike>(convert, x));
} else {
FRONT_END_OP_CONVERSION_CHECK(false,
"Openvino Pytorch Frontend doesn't support rounding mode ",
rounding_mode,
" for aten::div");
rounding_mode = context.const_input<std::string>(2);
}
if (rounding_mode.empty()) {
// If no rounding mode is given and both inputs are integral, cast both to fp32
const auto x_dtype = x.get_element_type();
const auto y_dtype = y.get_element_type();
if (x_dtype.is_static() && x_dtype.is_integral() && y_dtype.is_static() && y_dtype.is_integral()) {
x = context.mark_node(std::make_shared<v0::Convert>(x, element::f32));
y = context.mark_node(std::make_shared<v0::Convert>(y, element::f32));
}
}
align_eltwise_input_types(context, x, y, true);
auto res = context.mark_node(std::make_shared<v1::Divide>(x, y, true));
if (rounding_mode == "floor") {
res = context.mark_node(std::make_shared<v0::Floor>(res));
} else if (rounding_mode == "trunc") {
const auto convert = context.mark_node(std::make_shared<v0::Convert>(res, element::i64));
res = context.mark_node(std::make_shared<v1::ConvertLike>(convert, x));
}
return {res};
};

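PyTorch's aten::div performs true division by default and returns a floating-point result even for two integer inputs, which is why the translator casts both operands to fp32 when no rounding mode is given; "floor" rounds toward negative infinity and "trunc" toward zero. The reference behavior:

```python
import torch

a = torch.tensor([7, -7])
b = torch.tensor([2, 2])

# True division: integer inputs promote to floating point.
print(torch.div(a, b))                         # tensor([ 3.5000, -3.5000])
# "floor" rounds toward negative infinity.
print(torch.div(a, b, rounding_mode="floor"))  # tensor([ 3, -4])
# "trunc" rounds toward zero; translated above as Convert to i64
# followed by ConvertLike back to the input type.
print(torch.div(a, b, rounding_mode="trunc"))  # tensor([ 3, -3])
```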
29 changes: 29 additions & 0 deletions src/frontends/pytorch/src/op/mul.cpp
@@ -0,0 +1,29 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/multiply.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_mul(NodeContext& context) {
auto input_size = context.get_input_size();
FRONT_END_OP_CONVERSION_CHECK(input_size >= 2, "Operation has fewer than 2 inputs.");
auto x = context.get_input(0);
auto y = context.get_input(1);
for (int i = 2; i < input_size; i++) {
FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(i), "Got more inputs than expected.");
}
align_eltwise_input_types(context, x, y);
return {context.mark_node(std::make_shared<ov::op::v1::Multiply>(x, y))};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
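align_eltwise_input_types itself is not shown in this diff; a plausible reading, given these call sites, is that it reproduces PyTorch's binary type-promotion rules on the OpenVINO side. Those rules are exposed directly by torch.result_type (this is an assumption about the helper, not its actual implementation):

```python
import torch

# PyTorch's promotion rules for binary eltwise ops, via torch.result_type;
# assumption: align_eltwise_input_types mirrors this on the OpenVINO side.
i32 = torch.tensor([1], dtype=torch.int32)
i64 = torch.tensor([1], dtype=torch.int64)
f32 = torch.tensor([1.0], dtype=torch.float32)

print(torch.result_type(i32, f32))  # torch.float32 (float beats int)
print(torch.result_type(i32, i64))  # torch.int64 (wider int wins)
```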
13 changes: 9 additions & 4 deletions src/frontends/pytorch/src/op/sub.cpp
@@ -3,9 +3,13 @@
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/subtract.hpp"
#include "utils.hpp"

using namespace ov::op;

namespace ov {
namespace frontend {
namespace pytorch {
@@ -14,13 +18,14 @@ namespace op {
OutputVector translate_sub(NodeContext& context) {
auto x = context.get_input(0);
auto y = context.get_input(1);
align_eltwise_input_types(context, x, y);
// Default alpha is 1, so no multiplication is needed when alpha is not provided
if (!context.input_is_none(2)) {
auto alpha = context.get_input(2);
auto casted_alpha = context.mark_node(std::make_shared<opset10::ConvertLike>(alpha, y));
y = context.mark_node(std::make_shared<opset10::Multiply>(casted_alpha, y));
auto casted_alpha = context.mark_node(std::make_shared<v1::ConvertLike>(alpha, y));
y = context.mark_node(std::make_shared<v1::Multiply>(casted_alpha, y));
}
return {context.mark_node(std::make_shared<opset10::Subtract>(x, y))};
return {context.mark_node(std::make_shared<v1::Subtract>(x, y))};
};

} // namespace op
5 changes: 3 additions & 2 deletions src/frontends/pytorch/src/op_table.cpp
@@ -66,6 +66,7 @@ OP_CONVERTER(translate_masked_fill);
OP_CONVERTER(translate_mean);
OP_CONVERTER(translate_min);
OP_CONVERTER(translate_meshgrid);
OP_CONVERTER(translate_mul);
OP_CONVERTER(translate_neg);
OP_CONVERTER(translate_nonzero);
OP_CONVERTER(translate_norm);
@@ -228,8 +229,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::mm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::bmm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::matmul", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::mul", op::translate_1to1_match_2_inputs<opset10::Multiply>},
{"aten::mul_", op::inplace_op<op::translate_1to1_match_2_inputs<opset10::Multiply>>},
{"aten::mul", op::translate_mul},
{"aten::mul_", op::inplace_op<op::translate_mul>},
{"aten::ne", op::translate_1to1_match_2_inputs<opset10::NotEqual>},
{"aten::neg", op::translate_neg},
{"aten::norm", op::translate_norm},