-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Opset14][CORE] ConvertPromoteTypes op core implementation (#22566)
### Details: - *Implement ConvertPromoteTypes core - operation that aligns input types into one common one based on type promotion rules* - *Implement decomposition from ConvertPromoteTypes to pair of Convert ops* - *Operator designed for PT frontend, but rules extend to match also TF-like promotion rules* - *PR Contains tests for shape/type validation and visitor tests* ### Tickets: - *129198* - *129203* - *129202*
- Loading branch information
1 parent
b48b799
commit 733ed1b
Showing
12 changed files
with
1,101 additions
and
2 deletions.
There are no files selected for viewing
24 changes: 24 additions & 0 deletions
24
...on/transformations/include/transformations/op_conversions/convert_convertpromotetypes.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/pass/graph_rewrite.hpp" | ||
#include "transformations_visibility.hpp" | ||
|
||
namespace ov { | ||
namespace pass { | ||
|
||
class TRANSFORMATIONS_API ConvertConvertPromoteTypes; | ||
|
||
} // namespace pass | ||
} // namespace ov | ||
|
||
/// \brief Transformation to replace ConvertPromoteTypes with pair of Convert ops for each input to evaluated common | ||
/// element type. | ||
class ov::pass::ConvertConvertPromoteTypes : public ov::pass::MatcherPass { | ||
public: | ||
OPENVINO_RTTI("ConvertConvertPromoteTypes", "0"); | ||
ConvertConvertPromoteTypes(); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
42 changes: 42 additions & 0 deletions
42
...common/transformations/src/transformations/op_conversions/convert_convertpromotetypes.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/op_conversions/convert_convertpromotetypes.hpp" | ||
|
||
#include "itt.hpp" | ||
#include "openvino/core/rt_info.hpp" | ||
#include "openvino/op/convert.hpp" | ||
#include "openvino/op/convert_promote_types.hpp" | ||
#include "openvino/pass/pattern/op/wrap_type.hpp" | ||
|
||
ov::pass::ConvertConvertPromoteTypes::ConvertConvertPromoteTypes() { | ||
MATCHER_SCOPE(ConvertConvertPromoteTypes); | ||
|
||
auto has_static_defined_type = [](const Output<Node>& output) -> bool { | ||
return !pattern::type_matches_any({element::dynamic, element::undefined})(output); | ||
}; | ||
auto convert_promote_types = pattern::wrap_type<ov::op::v14::ConvertPromoteTypes>(has_static_defined_type); | ||
|
||
matcher_pass_callback callback = [](pattern::Matcher& m) { | ||
auto convert_promote_types = std::dynamic_pointer_cast<ov::op::v14::ConvertPromoteTypes>(m.get_match_root()); | ||
if (!convert_promote_types) { | ||
return false; | ||
} | ||
const element::Type& dest_type = convert_promote_types->get_output_element_type(0); | ||
const auto friendly_name = convert_promote_types->get_friendly_name(); | ||
NodeRegistry node_registry; | ||
const auto out0 = node_registry.make<ov::op::v0::Convert>(convert_promote_types->input_value(0), dest_type); | ||
const auto out1 = node_registry.make<ov::op::v0::Convert>(convert_promote_types->input_value(1), dest_type); | ||
out0->set_friendly_name(convert_promote_types->get_input_node_shared_ptr(0)->get_friendly_name() + "/" + | ||
friendly_name + ".0"); | ||
out1->set_friendly_name(convert_promote_types->get_input_node_shared_ptr(1)->get_friendly_name() + "/" + | ||
friendly_name + ".1"); | ||
copy_runtime_info(convert_promote_types, node_registry.get()); | ||
replace_node(convert_promote_types, {out0, out1}); | ||
return true; | ||
}; | ||
|
||
auto m = std::make_shared<pattern::Matcher>(convert_promote_types, matcher_name); | ||
this->register_matcher(m, callback); | ||
} |
99 changes: 99 additions & 0 deletions
99
src/common/transformations/tests/op_conversions/convert_convertpromotetypes_test.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/op_conversions/convert_convertpromotetypes.hpp" | ||
|
||
#include "common_test_utils/ov_test_utils.hpp" | ||
#include "openvino/op/convert.hpp" | ||
#include "openvino/op/convert_promote_types.hpp" | ||
#include "openvino/op/parameter.hpp" | ||
#include "openvino/pass/manager.hpp" | ||
|
||
namespace { | ||
|
||
struct ConvertPromoteTypesTestParams { | ||
ov::PartialShape lhs_shape; | ||
ov::element::Type lhs_type; | ||
ov::PartialShape rhs_shape; | ||
ov::element::Type rhs_type; | ||
ov::element::Type expected_type; | ||
}; | ||
|
||
class ConvertConvertPromoteTypesTest : public TransformationTestsF, | ||
public testing::WithParamInterface<ConvertPromoteTypesTestParams> { | ||
private: | ||
void SetUp() override { | ||
TransformationTestsF::SetUp(); | ||
const auto& parameters = GetParam(); | ||
const auto& lhsType = parameters.lhs_type; | ||
const auto& lhsShape = parameters.lhs_shape; | ||
const auto& rhsType = parameters.rhs_type; | ||
const auto& rhsShape = parameters.rhs_shape; | ||
const auto& alignType = parameters.expected_type; | ||
model = transform(lhsShape, lhsType, rhsShape, rhsType); | ||
model_ref = reference(lhsShape, lhsType, rhsShape, rhsType, alignType); | ||
manager.register_pass<ov::pass::ConvertConvertPromoteTypes>(); | ||
} | ||
|
||
protected: | ||
static std::shared_ptr<ov::Model> transform(const ov::PartialShape& lhsShape, | ||
const ov::element::Type& lhsType, | ||
const ov::PartialShape& rhsShape, | ||
const ov::element::Type& rhsType) { | ||
const auto lhs = std::make_shared<ov::op::v0::Parameter>(lhsType, lhsShape); | ||
const auto rhs = std::make_shared<ov::op::v0::Parameter>(rhsType, rhsShape); | ||
const auto convert_promote_types = std::make_shared<ov::op::v14::ConvertPromoteTypes>(lhs, rhs, true); | ||
return std::make_shared<ov::Model>(convert_promote_types->outputs(), ov::ParameterVector{lhs, rhs}, "Actual"); | ||
} | ||
|
||
static std::shared_ptr<ov::Model> reference(const ov::PartialShape& lhsShape, | ||
const ov::element::Type& lhsType, | ||
const ov::PartialShape& rhsShape, | ||
const ov::element::Type& rhsType, | ||
const ov::element::Type& alignType) { | ||
const auto lhs = std::make_shared<ov::op::v0::Parameter>(lhsType, lhsShape); | ||
const auto rhs = std::make_shared<ov::op::v0::Parameter>(rhsType, rhsShape); | ||
const auto lhs_converted = std::make_shared<ov::op::v0::Convert>(lhs, alignType); | ||
const auto rhs_converted = std::make_shared<ov::op::v0::Convert>(rhs, alignType); | ||
return std::make_shared<ov::Model>(ov::NodeVector{lhs_converted, rhs_converted}, | ||
ov::ParameterVector{lhs, rhs}, | ||
"Reference"); | ||
} | ||
}; | ||
INSTANTIATE_TEST_SUITE_P( | ||
ConvertPromoteTypesDecomposition, | ||
ConvertConvertPromoteTypesTest, | ||
testing::Values( | ||
ConvertPromoteTypesTestParams{ov::PartialShape::dynamic(), | ||
ov::element::f32, | ||
ov::PartialShape::dynamic(), | ||
ov::element::f32, | ||
ov::element::f32}, | ||
ConvertPromoteTypesTestParams{ov::PartialShape::dynamic(), | ||
ov::element::u16, | ||
{5, 6, 7}, | ||
ov::element::i16, | ||
ov::element::i32}, | ||
ConvertPromoteTypesTestParams{{1, 2, 3, 4}, | ||
ov::element::u16, | ||
ov::PartialShape::dynamic(), | ||
ov::element::f32, | ||
ov::element::f32}, | ||
ConvertPromoteTypesTestParams{{1, {3, 7}, -1, 4}, | ||
ov::element::f8e4m3, | ||
{0, 6, 7}, | ||
ov::element::f8e5m2, | ||
ov::element::f16}, | ||
ConvertPromoteTypesTestParams{{}, ov::element::bf16, {5, 6, 7}, ov::element::f16, ov::element::f32}, | ||
ConvertPromoteTypesTestParams{{1, 2, 3, 4}, ov::element::u16, {}, ov::element::boolean, ov::element::u16}, | ||
ConvertPromoteTypesTestParams{{}, ov::element::u16, {}, ov::element::u1, ov::element::u16}, | ||
ConvertPromoteTypesTestParams{{-1}, ov::element::u64, {1}, ov::element::i16, ov::element::f32}, | ||
ConvertPromoteTypesTestParams{{1, 2, 3, 4}, | ||
ov::element::boolean, | ||
{5, 6, 7}, | ||
ov::element::boolean, | ||
ov::element::boolean}, | ||
ConvertPromoteTypesTestParams{{1, 2, 3, 4}, ov::element::i64, {5, 6, 7}, ov::element::f16, ov::element::f16})); | ||
TEST_P(ConvertConvertPromoteTypesTest, CompareFunctions) {} | ||
} // namespace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/op/op.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace v14 { | ||
/// \brief Elementwise operation that promote and convert input types to one common datatype. | ||
/// \ingroup ov_ops_cpp_api | ||
class OPENVINO_API ConvertPromoteTypes : public Op { | ||
public: | ||
OPENVINO_OP("ConvertPromoteTypes", "opset14", op::Op); | ||
|
||
/// \brief Constructs operation that promote and convert input types to one common datatype. | ||
ConvertPromoteTypes() = default; | ||
/// \brief Constructs operation that promote and convert input types to one common datatype. | ||
/// \param input_0 Node with datatype to be promoted. | ||
/// \param input_1 Node with datatype to be promoted. | ||
/// \param promote_unsafe Bool attribute whether to allow promotions that might result in bit-widening, | ||
/// precision loss and undefined behaviors. | ||
/// \param pytorch_scalar_promotion Bool attribute whether to promote scalar input to type provided by non-scalar | ||
/// input when number format is matching. \param u64_integer_promotion_target Element type attribute to select | ||
/// promotion result for u64 and signed integers. | ||
ConvertPromoteTypes(const Output<Node>& input_0, | ||
const Output<Node>& input_1, | ||
const bool promote_unsafe = false, | ||
const bool pytorch_scalar_promotion = false, | ||
const element::Type& u64_integer_promotion_target = element::f32); | ||
void validate_and_infer_types() override; | ||
bool visit_attributes(AttributeVisitor& visitor) override; | ||
|
||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override; | ||
|
||
/// \brief Get bool attribute whether to promote scalar input to type provided by non-scalar input when number | ||
/// format is matching. | ||
bool get_pytorch_scalar_promotion() const; | ||
|
||
/// \brief Set bool attribute whether to promote scalar input to type provided by non-scalar input when number | ||
/// format is matching. | ||
void set_pytorch_scalar_promotion(bool pytorch_scalar_promotion); | ||
|
||
/// \brief Get bool attribute whether to allow promotions that might result in bit-widening, precision loss and | ||
/// undefined behaviors. | ||
bool get_promote_unsafe() const; | ||
|
||
/// \brief Set bool attribute whether to allow promotions that might result in bit-widening, precision loss and | ||
/// undefined behaviors. | ||
void set_promote_unsafe(bool promote_unsafe); | ||
|
||
/// \brief Get element type attribute to select promotion result for u64 and signed integers. | ||
const element::Type& get_u64_integer_promotion_target() const; | ||
|
||
/// \brief Set element type attribute to select promotion result for u64 and signed integers. | ||
void set_u64_integer_promotion_target(const element::Type& u64_integer_promotion_target); | ||
|
||
private: | ||
bool m_promote_unsafe = false; | ||
bool m_pytorch_scalar_promotion = false; | ||
element::Type m_u64_integer_promotion_target = element::f32; | ||
}; | ||
} // namespace v14 | ||
} // namespace op | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.