[PT FE][Opset14] Replace PT FE AlignTypes with opset14 ConvertPromoteTypes #22770

Merged (40 commits) on Mar 15, 2024

Changes from all commits
dfeebbd
Replace PT FE AlignTypes with opset14 ConvertPromoteTypes
mmikolajcz Feb 10, 2024
9805aa2
PT FE support PyScalar type promotion
mmikolajcz Feb 15, 2024
6d80606
Add bool as PyScalar
mmikolajcz Feb 15, 2024
df569c2
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
mmikolajcz Feb 15, 2024
4028e05
Add PyScalar check for distance PT FE op
mmikolajcz Feb 15, 2024
3ca5fcc
Fix ConvertPromoteTypes for torch with dynamic rank
mmikolajcz Feb 20, 2024
1443ccc
Update torch type promotion
mmikolajcz Feb 20, 2024
570947a
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
mmikolajcz Feb 20, 2024
9d9260b
Modify approach to handle tensor-scalar promotion
mmikolajcz Feb 21, 2024
4e6edbe
Revert int change to i64
mmikolajcz Feb 23, 2024
169e750
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
mmikolajcz Feb 23, 2024
cc26dd6
Fix simplified_type_interpret for PyScalar
mmikolajcz Feb 23, 2024
b536b74
Fix issue in device tests
mmikolajcz Feb 26, 2024
83fb5d0
Enable add bool test
mmikolajcz Feb 26, 2024
dcbbf8e
Add helpers for type promotion
mmikolajcz Feb 26, 2024
61ae5f2
Use helpers in existing conversions
mmikolajcz Feb 26, 2024
fc95556
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
mmikolajcz Feb 26, 2024
ee3af7f
Fix compilation error
mmikolajcz Feb 27, 2024
2a3f250
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
mmikolajcz Feb 28, 2024
f2db323
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
mmikolajcz Feb 29, 2024
90b574e
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
mmikolajcz Mar 4, 2024
975b77a
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
mmikolajcz Mar 5, 2024
4965899
Add missing to floating conversions
mmikolajcz Mar 5, 2024
61112e6
Fix imports
mmikolajcz Mar 5, 2024
c8bd8f3
Improve erfc
mmikolajcz Mar 5, 2024
571d9cb
Remove unused variable
mmikolajcz Mar 5, 2024
e40b7b5
Remove newline
mmikolajcz Mar 8, 2024
a6f4928
Disable transform
mmikolajcz Mar 8, 2024
5345d13
Remove align to lhs
mmikolajcz Mar 8, 2024
c78fdbe
Merge branch 'master' into mateuszm/op/align/torch
mmikolajcz Mar 8, 2024
971ff78
Apply suggestions from code review
mmikolajcz Mar 12, 2024
62b98f3
Merge branch 'master' into mateuszm/op/align/torch
mmikolajcz Mar 12, 2024
83f2e42
Merge branch 'master' into mateuszm/op/align/torch
mmikolajcz Mar 12, 2024
4a93c85
Merge branch 'master' into mateuszm/op/align/torch
mlukasze Mar 12, 2024
2c0f7a9
Merge branch 'master' into mateuszm/op/align/torch
mmikolajcz Mar 13, 2024
b424415
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
mmikolajcz Mar 14, 2024
fa90edc
Merge branch 'mateuszm/op/align/torch' of https://github.com/mmikolaj…
mmikolajcz Mar 14, 2024
264e783
Fix issue with dynamic types in Mish-4 op
mmikolajcz Mar 14, 2024
26f554b
Improve handling dynamic ranks for mixed Tensor+Scalar cases
mmikolajcz Mar 14, 2024
728127f
Fix code style
mmikolajcz Mar 14, 2024
@@ -185,7 +185,9 @@ def _get_known_type_for_value(self, pt_type):
if pt_type is None:
return OVAny(OVType.dynamic)
# TODO: Don't use str, use native types
if str(pt_type) in pt_to_ov_type_map:
if str(pt_type) in ["int", "float", "bool"]:
return OVAny(DecoderType.PyScalar(OVAny(pt_to_ov_type_map[str(pt_type)])))
elif str(pt_type) in pt_to_ov_type_map:
return OVAny(pt_to_ov_type_map[str(pt_type)])
elif isinstance(pt_type, torch.TensorType):
# Tensor type, parse element type
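
For context: with the decoder change above, Python literals typed as int, float, or bool are now reported as PyScalar wrappers around their element type instead of plain types, so later type promotion can give them lower priority than real tensors. A minimal standalone sketch of the idea (illustrative names and placeholder type strings, not the actual OpenVINO decoder API):

# Illustrative sketch only; names and type strings are placeholders.
PT_TO_OV_TYPE = {"int": "i32", "float": "f32", "bool": "boolean"}

def classify_known_type(pt_type):
    # Python scalars keep their element type but are tagged as PyScalar so that
    # promotion against real tensors can treat them with lower priority.
    if pt_type in ("int", "float", "bool"):
        return ("PyScalar", PT_TO_OV_TYPE[pt_type])
    if pt_type in PT_TO_OV_TYPE:
        return ("Type", PT_TO_OV_TYPE[pt_type])
    return ("Dynamic", None)

print(classify_known_type("float"))    # ('PyScalar', 'f32')
print(classify_known_type("complex"))  # ('Dynamic', None)
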
@@ -32,4 +32,6 @@ void regclass_frontend_pytorch_decoder(py::module m) {
def(py::init<>());
py::class_<type::PyNone>(type_module, "PyNone").
def(py::init<>());
py::class_<type::PyScalar>(type_module, "PyScalar").
def(py::init<Any>());
}
4 changes: 4 additions & 0 deletions src/bindings/python/src/pyopenvino/utils/utils.cpp
@@ -205,6 +205,8 @@ py::object from_ov_any(const ov::Any& any) {
return py::cast(any.as<ov::frontend::type::Str>());
} else if (any.is<ov::frontend::type::PyNone>()) {
return py::cast(any.as<ov::frontend::type::PyNone>());
} else if (any.is<ov::frontend::type::PyScalar>()) {
return py::cast(any.as<ov::frontend::type::PyScalar>());
} else {
PyErr_SetString(PyExc_TypeError, "Failed to convert parameter to Python representation!");
return py::cast<py::object>((PyObject*)NULL);
@@ -402,6 +404,8 @@ ov::Any py_object_to_any(const py::object& py_obj) {
return py::cast<ov::frontend::type::Str>(py_obj);
} else if (py::isinstance<ov::frontend::type::PyNone>(py_obj)) {
return py::cast<ov::frontend::type::PyNone>(py_obj);
} else if (py::isinstance<ov::frontend::type::PyScalar>(py_obj)) {
return py::cast<ov::frontend::type::PyScalar>(py_obj);
// If there is no match fallback to py::object
} else if (py::isinstance<py::object>(py_obj)) {
return py_obj;
49 changes: 28 additions & 21 deletions src/core/src/op/convert_promote_types.cpp
@@ -91,8 +91,6 @@ element::Type evaluate_common_type(const v14::ConvertPromoteTypes* op) {

const auto is_input_0_real = input_0_type.is_real();
const auto is_input_1_real = input_1_type.is_real();
const size_t input_0_bitwidth = input_0_type.bitwidth();
const size_t input_1_bitwidth = input_1_type.bitwidth();

if (is_input_0_real != is_input_1_real) {
// Floating and integer mixed, align to floating
@@ -109,27 +107,36 @@

} else if (is_input_0_real == is_input_1_real) {
// Type formats are the same (both are either floating or integer).
const auto& input_0_pshape = op->get_input_partial_shape(0);
const auto& input_1_pshape = op->get_input_partial_shape(1);
const auto is_input_0_scalar = input_0_pshape.is_static() && is_scalar(input_0_pshape);
const auto is_input_1_scalar = input_1_pshape.is_static() && is_scalar(input_1_pshape);
if (pytorch_scalar_promotion) {
const auto& input_0_rank = op->get_input_partial_shape(0).rank();
const auto& input_1_rank = op->get_input_partial_shape(1).rank();
if (input_0_rank.is_dynamic() || input_1_rank.is_dynamic()) {
// For pytorch mode, return element::dynamic if ranks affecting output type are dynamic.
return element::dynamic;
}
const auto is_input_0_scalar = input_0_rank.get_length() == 0;
const auto is_input_1_scalar = input_1_rank.get_length() == 0;
if (is_input_0_scalar != is_input_1_scalar) {
// For pytorch mode, when number formats are same, promote to type of non-scalar input.
const auto& target = is_input_0_scalar ? input_1_type : input_0_type;
if (!promote_unsafe) {
// For safe mode, check whether target type has bitwidth able to hold data from scalar type.
const auto& scalar = is_input_0_scalar ? input_0_type : input_1_type;
const auto is_pytorch_promote_safe =
((target.is_signed() == scalar.is_signed() && target.bitwidth() >= scalar.bitwidth()) ||
(target.is_signed() && !scalar.is_signed() && target.bitwidth() * 2 >= scalar.bitwidth()));
NODE_VALIDATION_CHECK(op,
is_pytorch_promote_safe,
"Scalar input cannot be PyTorch-like promoted using safe promotion rules.");
}
return target;
}
}
const auto is_input_0_signed = input_0_type.is_signed();
const auto is_input_1_signed = input_1_type.is_signed();
if (pytorch_scalar_promotion && (is_input_0_scalar != is_input_1_scalar)) {
// For pytorch mode, when number formats are same, promote to type of non-scalar input.
const auto target = is_input_0_scalar ? input_1_type : input_0_type;
if (!promote_unsafe) {
// For safe mode, check whether target type has bitwidth able to hold data from scalar type.
const auto scalar = is_input_0_scalar ? input_0_type : input_1_type;
const auto is_pytorch_promote_safe =
((target.is_signed() == scalar.is_signed() && target.bitwidth() >= scalar.bitwidth()) ||
(target.is_signed() && !scalar.is_signed() && target.bitwidth() * 2 >= scalar.bitwidth()));
NODE_VALIDATION_CHECK(op,
is_pytorch_promote_safe,
"Scalar input cannot be PyTorch-like promoted using safe promotion rules.");
}
return target;
} else if ((is_input_0_signed != is_input_1_signed)) {
const auto input_0_bitwidth = input_0_type.bitwidth();
const auto input_1_bitwidth = input_1_type.bitwidth();
if ((is_input_0_signed != is_input_1_signed)) {
// Signed and unsigned integers are mixed, convert to signed integer with bitwidth able to hold all unsigned
// data. Exception for u64 + integer - either convert to type from `u64_promotion_target` or fail in safe
// mode.
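
For context, the scalar-aware branch above mirrors PyTorch's promotion rules, where Python scalars and zero-dimensional tensors of the same number category do not widen the dtype of a dimensioned tensor operand, while mixing categories or combining two dimensioned tensors promotes as usual. A quick illustration of the behavior being reproduced, assuming a standard PyTorch installation:

import torch

t = torch.tensor([1, 2], dtype=torch.uint8)
print((t + 2).dtype)    # torch.uint8   - an int scalar does not widen the integer tensor
print((t + 2.5).dtype)  # torch.float32 - a float scalar switches the category to floating point

a = torch.tensor([1], dtype=torch.int32)
b = torch.tensor([1], dtype=torch.int64)
print((a + b).dtype)    # torch.int64   - tensor-tensor promotion picks the wider type
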
2 changes: 1 addition & 1 deletion src/core/src/op/mish.cpp
@@ -36,7 +36,7 @@ void Mish::validate_and_infer_types() {

const auto& data_batch_et = get_input_element_type(0);
NODE_VALIDATION_CHECK(this,
data_batch_et.is_real(),
data_batch_et.is_real() || data_batch_et.is_dynamic(),
"Element must be of floating point type, Got: ",
data_batch_et);

10 changes: 9 additions & 1 deletion src/core/tests/type_prop/convert_promote_types.cpp
@@ -285,7 +285,15 @@ INSTANTIATE_TEST_SUITE_P(type_prop_pytorch_mode,
ov::element::u8,
true,
true,
ov::element::u8,
ov::element::dynamic,
ov::element::f32},
ConvertPromoteTypesTestParams{{},
ov::element::i32,
ov::PartialShape().dynamic(),
ov::element::f16,
true,
true,
ov::element::f16,
ov::element::f32},
ConvertPromoteTypesTestParams{{},
ov::element::f16,
6 changes: 6 additions & 0 deletions src/frontends/common/include/openvino/frontend/decoder.hpp
@@ -35,6 +35,12 @@ struct Str {};

struct PyNone {};

struct PyScalar {
PyScalar() = default;
explicit PyScalar(const Any& _element_type) : element_type(_element_type) {}
Any element_type;
};

struct Optional;
struct Dict;
struct NamedTuple;
6 changes: 2 additions & 4 deletions src/frontends/pytorch/src/frontend.cpp
@@ -19,9 +19,9 @@
#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
#include "transformations/low_precision/mark_dequantization_subgraph.hpp"
#include "transformations/op_conversions/convert_convertlike.hpp"
#include "transformations/op_conversions/convert_convertpromotetypes.hpp"
#include "transformations/resolve_names_collisions.hpp"
#include "transforms.hpp"
#include "transforms/align_types_removal.hpp"
#include "transforms/append_list_unpack_replacer.hpp"
#include "transforms/aten_cat_replacer.hpp"
#include "transforms/aten_getitem_replacer.hpp"
@@ -185,7 +185,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::pass::MarkCompressedFloatConstants>();
manager.register_pass<ov::pass::ConstantFolding>();

manager.register_pass<ov::frontend::pytorch::pass::AlignTypesRemoval>();
manager.register_pass<ov::pass::ConvertConvertPromoteTypes>();
manager.register_pass<ov::pass::PushConstantToSubgraph>();
manager.register_pass<ov::pass::UnrollIf>();
manager.register_pass<ov::frontend::pytorch::pass::TupleUnpackInBodyReplacer>();
@@ -218,8 +218,6 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::frontend::pytorch::pass::RemovePackingOps>();
manager.register_pass<ov::pass::RemoveMultiSubGraphOpDanglingParamsResults>();
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
// Second pass of AlignTypesRemoval after all converting transformations
manager.register_pass<ov::frontend::pytorch::pass::AlignTypesRemoval>();
manager.register_pass<ov::pass::ResolveNameCollisions>(true);
manager.run_passes(model);

43 changes: 0 additions & 43 deletions src/frontends/pytorch/src/helper_ops/align_types.hpp

This file was deleted.

20 changes: 11 additions & 9 deletions src/frontends/pytorch/src/op/add.cpp
@@ -20,15 +20,23 @@ using namespace ov::op;

OutputVector translate_add_common(const NodeContext& context, bool inplace) {
num_inputs_check(context, 2, 3);
auto lhs = context.get_input(0);
auto rhs = context.get_input(1);
Output<Node> lhs;
Output<Node> rhs;
auto dtype0 = context.get_input_type(0);
auto dtype1 = context.get_input_type(1);
if (dtype0.is<type::List>() && dtype1.is<type::List>()) {
// aten::add.t(t[] a, t[] b) -> t[]
// Case when two lists gets concatenated
PYTORCH_OP_CONVERSION_CHECK(false, "aten::add is used for concatenation of lists, not possible to convert");
}
if (inplace) {
lhs = context.get_input(0);
rhs = context.get_input(1);
if (lhs.get_element_type().is_dynamic() || lhs.get_element_type() != rhs.get_element_type())
rhs = context.mark_node(std::make_shared<v1::ConvertLike>(rhs, lhs));
} else {
std::tie(lhs, rhs) = get_inputs_with_promoted_types(context, 0, 1);
}

auto left_is_bool = lhs.get_element_type() == ov::element::boolean ||
(dtype0.is<element::Type>() && dtype0.as<element::Type>() == element::boolean);
@@ -44,12 +52,6 @@ OutputVector translate_add_common(const NodeContext& context, bool inplace) {
return {logical_or};
}

if (inplace) {
if (lhs.get_element_type().is_dynamic() || lhs.get_element_type() != rhs.get_element_type())
rhs = context.mark_node(std::make_shared<v1::ConvertLike>(rhs, lhs));
} else {
align_eltwise_input_types(context, lhs, rhs, true);
}
Output<Node> alpha;
if (!context.input_is_none(2)) {
alpha = context.get_input(2);
@@ -77,4 +79,4 @@ OutputVector translate_add_(const NodeContext& context) {
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
} // namespace ov
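
A side note on the boolean branch kept in translate_add_common above: it reflects PyTorch semantics, where adding two boolean tensors behaves as a logical OR rather than arithmetic addition, which is why the converter emits LogicalOr before the generic promotion path. A quick check, assuming a standard PyTorch installation:

import torch

a = torch.tensor([True, False, True])
b = torch.tensor([True, True, False])
print(a + b)  # tensor([True, True, True]) - bool + bool acts as an elementwise logical OR
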
18 changes: 9 additions & 9 deletions src/frontends/pytorch/src/op/bitwise.cpp
@@ -26,9 +26,9 @@ OutputVector translate_bitwise_not(const NodeContext& context) {

OutputVector translate_bitwise_and(const NodeContext& context) {
num_inputs_check(context, 2, 3);
auto x = context.get_input(0);
auto y = context.get_input(1);
align_eltwise_input_types(context, x, y, false);
Output<Node> x;
Output<Node> y;
std::tie(x, y) = get_inputs_with_promoted_types(context, 0, 1);
auto and_x = context.mark_node(std::make_shared<ov::op::v13::BitwiseAnd>(x, y));
if (!context.input_is_none(2)) {
context.mutate_input(2, and_x);
@@ -38,9 +38,9 @@ OutputVector translate_bitwise_or(const NodeContext& context) {

OutputVector translate_bitwise_or(const NodeContext& context) {
num_inputs_check(context, 2, 3);
auto x = context.get_input(0);
auto y = context.get_input(1);
align_eltwise_input_types(context, x, y, false);
Output<Node> x;
Output<Node> y;
std::tie(x, y) = get_inputs_with_promoted_types(context, 0, 1);
auto or_x = context.mark_node(std::make_shared<ov::op::v13::BitwiseOr>(x, y));
if (!context.input_is_none(2)) {
context.mutate_input(2, or_x);
@@ -50,9 +50,9 @@ OutputVector translate_bitwise_xor(const NodeContext& context) {

OutputVector translate_bitwise_xor(const NodeContext& context) {
num_inputs_check(context, 2, 3);
auto x = context.get_input(0);
auto y = context.get_input(1);
align_eltwise_input_types(context, x, y, false);
Output<Node> x;
Output<Node> y;
std::tie(x, y) = get_inputs_with_promoted_types(context, 0, 1);
auto xor_x = context.mark_node(std::make_shared<ov::op::v13::BitwiseXor>(x, y));
if (!context.input_is_none(2)) {
context.mutate_input(2, xor_x);
14 changes: 7 additions & 7 deletions src/frontends/pytorch/src/op/cross.cpp
@@ -35,9 +35,9 @@ OutputVector translate_linalg_cross(const NodeContext& context) {
// aten::linalg_cross(Tensor self, Tensor other, int? dim=-1) -> Tensor
// aten::linalg_cross.out(Tensor self, Tensor other, int? dim=-1, *, Tensor(a!) out) -> Tensor(a!)
num_inputs_check(context, 3, 4);
auto self = context.get_input(0);
auto other = context.get_input(1);
align_eltwise_input_types(context, self, other, true);
Output<Node> self;
Output<Node> other;
std::tie(self, other) = get_inputs_with_promoted_types(context, 0, 1);
auto const_minus_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
Output<Node> dim;
if (context.input_is_none(2)) {
@@ -59,9 +59,9 @@ OutputVector translate_cross(const NodeContext& context) {
// aten::cross(Tensor self, Tensor other, int? dim=None) -> Tensor
// aten::cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)
num_inputs_check(context, 3, 4);
auto self = context.get_input(0);
auto other = context.get_input(1);
align_eltwise_input_types(context, self, other, true);
Output<Node> self;
Output<Node> other;
std::tie(self, other) = get_inputs_with_promoted_types(context, 0, 1);
Output<Node> dim;
if (context.input_is_none(2)) {
// If dim is not given, it defaults to the first dimension found with the size 3
@@ -98,4 +98,4 @@ OutputVector translate_cross(const NodeContext& context) {
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
} // namespace ov
4 changes: 2 additions & 2 deletions src/frontends/pytorch/src/op/distance.cpp
@@ -33,7 +33,7 @@ Output<Node> pairwise_distance(const NodeContext& context,
auto p_plus_eps = context.mark_node(std::make_shared<v1::Add>(p, eps));
auto inv_p = context.mark_node(std::make_shared<v1::Divide>(one, p_plus_eps));
auto minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
align_eltwise_input_types(context, x, y, true);
align_eltwise_input_types(context, x, y, is_python_scalar_input(context, 0), is_python_scalar_input(context, 1));
auto x_y_diff = context.mark_node(std::make_shared<v1::Subtract>(x, y));
auto x_y_diff_in_p_power = context.mark_node(std::make_shared<v1::Power>(x_y_diff, p));
auto summation = context.mark_node(std::make_shared<v1::ReduceSum>(x_y_diff_in_p_power, minus_one, keepdim));
@@ -91,4 +91,4 @@ OutputVector translate_pairwise_distance(const NodeContext& context) {
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
} // namespace ov
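
The pairwise_distance helper above follows the same formula as torch.nn.functional.pairwise_distance: raise |x - y + eps| to the power p, sum over the last axis, then take the 1/p root. A small numeric sanity check of that formula, assuming a standard PyTorch installation:

import torch
import torch.nn.functional as F

x = torch.tensor([[1.0, 2.0, 3.0]])
y = torch.tensor([[1.0, 0.0, 0.0]])
# With p=2 this is the Euclidean distance over the last dimension: sqrt(0^2 + 2^2 + 3^2)
print(F.pairwise_distance(x, y, p=2.0))  # tensor([3.6056]), i.e. sqrt(13)
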
6 changes: 5 additions & 1 deletion src/frontends/pytorch/src/op/div.cpp
@@ -37,7 +37,11 @@ OutputVector translate_div_common(const NodeContext& context,
if (x.get_element_type().is_dynamic() || x.get_element_type() != y.get_element_type())
y = context.mark_node(std::make_shared<v1::ConvertLike>(y, x));
} else {
align_eltwise_input_types(context, x, y, true);
align_eltwise_input_types(context,
x,
y,
is_python_scalar_input(context, 0),
is_python_scalar_input(context, 1));
}
auto res = context.mark_node(std::make_shared<v1::Divide>(x, y, true));
// TODO: ticket 103296; Temporarily disable ConvertDivide transformation
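
Division is one of the trickier promotion cases the converter has to match: in PyTorch, true division always yields a floating-point result even for integer inputs, while integer dtypes are kept only when a rounding mode is requested. For example, assuming a standard PyTorch installation:

import torch

a = torch.tensor([5, 7])
b = torch.tensor([2, 2])
print((a / b).dtype)                                 # torch.float32 - true division promotes to float
print(torch.div(a, b, rounding_mode="floor").dtype)  # torch.int64   - floor division keeps the integer dtype
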
7 changes: 3 additions & 4 deletions src/frontends/pytorch/src/op/erfc.cpp
@@ -20,13 +20,12 @@ OutputVector translate_erfc(const NodeContext& context) {
// aten::erfc(Tensor self) -> Tensor
// aten::erfc.out(Tensor self, Tensor(a!) out) -> Tensor(a!)
num_inputs_check(context, 1, 2);
auto x = context.get_input(0);
auto x = get_input_with_floating_type(context, 0);

// create 'ones' to use to calculate complementary of Erf output
auto ones = context.mark_node(make_shared<v0::Constant>(element::f32, Shape{}, 1.0f))->output(0);

// align data types of input 'x' and ones
align_eltwise_input_types(context, x, ones);
ones = context.mark_node(std::make_shared<v1::ConvertLike>(ones, x));

// apply Erf to the input tensor 'x'
auto y = context.mark_node(make_shared<v0::Erf>(x));
@@ -42,4 +41,4 @@ OutputVector translate_erfc(const NodeContext& context) {
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
} // namespace ov
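
The rewritten translation above computes erfc(x) as 1 - erf(x) on an input that is first converted to a floating-point type. The identity is easy to confirm in PyTorch, assuming a standard installation:

import torch

x = torch.tensor([0.0, 1.0, 2.0])
print(torch.erfc(x))       # approximately [1.0000, 0.1573, 0.0047]
print(1.0 - torch.erf(x))  # same values: erfc(x) == 1 - erf(x)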