Skip to content

Commit

Permalink
[Opset14][CORE] ConvertPromoteTypes op core implementation (#22566)
Browse files Browse the repository at this point in the history
### Details:
- *Implement ConvertPromoteTypes core - operation that aligns input
types into one common one based on type promotion rules*
- *Implement decomposition from ConvertPromoteTypes to pair of Convert
ops*
- *Operator designed for PT frontend, but rules extend to match also
TF-like promotion rules*
 - *PR Contains tests for shape/type validation and visitor tests*

### Tickets:
 - *129198*
 - *129203*
 - *129202*
  • Loading branch information
mmikolajcz authored Feb 9, 2024
1 parent b48b799 commit 733ed1b
Show file tree
Hide file tree
Showing 12 changed files with 1,101 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API ConvertConvertPromoteTypes;

} // namespace pass
} // namespace ov

/// \brief Transformation to replace ConvertPromoteTypes with pair of Convert ops for each input to evaluated common
/// element type.
class ov::pass::ConvertConvertPromoteTypes : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ConvertConvertPromoteTypes", "0");
ConvertConvertPromoteTypes();
};
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
#include "transformations/op_conversions/convert_bitwise_to_logical_bool.hpp"
#include "transformations/op_conversions/convert_broadcast_to_tiles.hpp"
#include "transformations/op_conversions/convert_convertlike.hpp"
#include "transformations/op_conversions/convert_convertpromotetypes.hpp"
#include "transformations/op_conversions/convert_deformable_conv_v8_to_v1.hpp"
#include "transformations/op_conversions/convert_depth_to_space.hpp"
#include "transformations/op_conversions/convert_divide.hpp"
Expand Down Expand Up @@ -164,6 +165,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
ADD_MATCHER(decomp, ConvertDivide)
ADD_MATCHER(decomp, ConvertDepthToSpace)
ADD_MATCHER(decomp, ConvertSpaceToDepth)
ADD_MATCHER(decomp, ConvertConvertPromoteTypes)
ADD_MATCHER(decomp, ConvertConvertLike)
ADD_MATCHER(decomp, BatchNormDecomposition)
ADD_MATCHER(decomp, GroupNormalizationDecomposition)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/op_conversions/convert_convertpromotetypes.hpp"

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_promote_types.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

ov::pass::ConvertConvertPromoteTypes::ConvertConvertPromoteTypes() {
MATCHER_SCOPE(ConvertConvertPromoteTypes);

auto has_static_defined_type = [](const Output<Node>& output) -> bool {
return !pattern::type_matches_any({element::dynamic, element::undefined})(output);
};
auto convert_promote_types = pattern::wrap_type<ov::op::v14::ConvertPromoteTypes>(has_static_defined_type);

matcher_pass_callback callback = [](pattern::Matcher& m) {
auto convert_promote_types = std::dynamic_pointer_cast<ov::op::v14::ConvertPromoteTypes>(m.get_match_root());
if (!convert_promote_types) {
return false;
}
const element::Type& dest_type = convert_promote_types->get_output_element_type(0);
const auto friendly_name = convert_promote_types->get_friendly_name();
NodeRegistry node_registry;
const auto out0 = node_registry.make<ov::op::v0::Convert>(convert_promote_types->input_value(0), dest_type);
const auto out1 = node_registry.make<ov::op::v0::Convert>(convert_promote_types->input_value(1), dest_type);
out0->set_friendly_name(convert_promote_types->get_input_node_shared_ptr(0)->get_friendly_name() + "/" +
friendly_name + ".0");
out1->set_friendly_name(convert_promote_types->get_input_node_shared_ptr(1)->get_friendly_name() + "/" +
friendly_name + ".1");
copy_runtime_info(convert_promote_types, node_registry.get());
replace_node(convert_promote_types, {out0, out1});
return true;
};

auto m = std::make_shared<pattern::Matcher>(convert_promote_types, matcher_name);
this->register_matcher(m, callback);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/op_conversions/convert_convertpromotetypes.hpp"

#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_promote_types.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/pass/manager.hpp"

namespace {

struct ConvertPromoteTypesTestParams {
ov::PartialShape lhs_shape;
ov::element::Type lhs_type;
ov::PartialShape rhs_shape;
ov::element::Type rhs_type;
ov::element::Type expected_type;
};

class ConvertConvertPromoteTypesTest : public TransformationTestsF,
public testing::WithParamInterface<ConvertPromoteTypesTestParams> {
private:
void SetUp() override {
TransformationTestsF::SetUp();
const auto& parameters = GetParam();
const auto& lhsType = parameters.lhs_type;
const auto& lhsShape = parameters.lhs_shape;
const auto& rhsType = parameters.rhs_type;
const auto& rhsShape = parameters.rhs_shape;
const auto& alignType = parameters.expected_type;
model = transform(lhsShape, lhsType, rhsShape, rhsType);
model_ref = reference(lhsShape, lhsType, rhsShape, rhsType, alignType);
manager.register_pass<ov::pass::ConvertConvertPromoteTypes>();
}

protected:
static std::shared_ptr<ov::Model> transform(const ov::PartialShape& lhsShape,
const ov::element::Type& lhsType,
const ov::PartialShape& rhsShape,
const ov::element::Type& rhsType) {
const auto lhs = std::make_shared<ov::op::v0::Parameter>(lhsType, lhsShape);
const auto rhs = std::make_shared<ov::op::v0::Parameter>(rhsType, rhsShape);
const auto convert_promote_types = std::make_shared<ov::op::v14::ConvertPromoteTypes>(lhs, rhs, true);
return std::make_shared<ov::Model>(convert_promote_types->outputs(), ov::ParameterVector{lhs, rhs}, "Actual");
}

static std::shared_ptr<ov::Model> reference(const ov::PartialShape& lhsShape,
const ov::element::Type& lhsType,
const ov::PartialShape& rhsShape,
const ov::element::Type& rhsType,
const ov::element::Type& alignType) {
const auto lhs = std::make_shared<ov::op::v0::Parameter>(lhsType, lhsShape);
const auto rhs = std::make_shared<ov::op::v0::Parameter>(rhsType, rhsShape);
const auto lhs_converted = std::make_shared<ov::op::v0::Convert>(lhs, alignType);
const auto rhs_converted = std::make_shared<ov::op::v0::Convert>(rhs, alignType);
return std::make_shared<ov::Model>(ov::NodeVector{lhs_converted, rhs_converted},
ov::ParameterVector{lhs, rhs},
"Reference");
}
};
INSTANTIATE_TEST_SUITE_P(
ConvertPromoteTypesDecomposition,
ConvertConvertPromoteTypesTest,
testing::Values(
ConvertPromoteTypesTestParams{ov::PartialShape::dynamic(),
ov::element::f32,
ov::PartialShape::dynamic(),
ov::element::f32,
ov::element::f32},
ConvertPromoteTypesTestParams{ov::PartialShape::dynamic(),
ov::element::u16,
{5, 6, 7},
ov::element::i16,
ov::element::i32},
ConvertPromoteTypesTestParams{{1, 2, 3, 4},
ov::element::u16,
ov::PartialShape::dynamic(),
ov::element::f32,
ov::element::f32},
ConvertPromoteTypesTestParams{{1, {3, 7}, -1, 4},
ov::element::f8e4m3,
{0, 6, 7},
ov::element::f8e5m2,
ov::element::f16},
ConvertPromoteTypesTestParams{{}, ov::element::bf16, {5, 6, 7}, ov::element::f16, ov::element::f32},
ConvertPromoteTypesTestParams{{1, 2, 3, 4}, ov::element::u16, {}, ov::element::boolean, ov::element::u16},
ConvertPromoteTypesTestParams{{}, ov::element::u16, {}, ov::element::u1, ov::element::u16},
ConvertPromoteTypesTestParams{{-1}, ov::element::u64, {1}, ov::element::i16, ov::element::f32},
ConvertPromoteTypesTestParams{{1, 2, 3, 4},
ov::element::boolean,
{5, 6, 7},
ov::element::boolean,
ov::element::boolean},
ConvertPromoteTypesTestParams{{1, 2, 3, 4}, ov::element::i64, {5, 6, 7}, ov::element::f16, ov::element::f16}));
TEST_P(ConvertConvertPromoteTypesTest, CompareFunctions) {}
} // namespace
67 changes: 67 additions & 0 deletions src/core/include/openvino/op/convert_promote_types.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"

namespace ov {
namespace op {
namespace v14 {
/// \brief Elementwise operation that promote and convert input types to one common datatype.
/// \ingroup ov_ops_cpp_api
class OPENVINO_API ConvertPromoteTypes : public Op {
public:
OPENVINO_OP("ConvertPromoteTypes", "opset14", op::Op);

/// \brief Constructs operation that promote and convert input types to one common datatype.
ConvertPromoteTypes() = default;
/// \brief Constructs operation that promote and convert input types to one common datatype.
/// \param input_0 Node with datatype to be promoted.
/// \param input_1 Node with datatype to be promoted.
/// \param promote_unsafe Bool attribute whether to allow promotions that might result in bit-widening,
/// precision loss and undefined behaviors.
/// \param pytorch_scalar_promotion Bool attribute whether to promote scalar input to type provided by non-scalar
/// input when number format is matching. \param u64_integer_promotion_target Element type attribute to select
/// promotion result for u64 and signed integers.
ConvertPromoteTypes(const Output<Node>& input_0,
const Output<Node>& input_1,
const bool promote_unsafe = false,
const bool pytorch_scalar_promotion = false,
const element::Type& u64_integer_promotion_target = element::f32);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

/// \brief Get bool attribute whether to promote scalar input to type provided by non-scalar input when number
/// format is matching.
bool get_pytorch_scalar_promotion() const;

/// \brief Set bool attribute whether to promote scalar input to type provided by non-scalar input when number
/// format is matching.
void set_pytorch_scalar_promotion(bool pytorch_scalar_promotion);

/// \brief Get bool attribute whether to allow promotions that might result in bit-widening, precision loss and
/// undefined behaviors.
bool get_promote_unsafe() const;

/// \brief Set bool attribute whether to allow promotions that might result in bit-widening, precision loss and
/// undefined behaviors.
void set_promote_unsafe(bool promote_unsafe);

/// \brief Get element type attribute to select promotion result for u64 and signed integers.
const element::Type& get_u64_integer_promotion_target() const;

/// \brief Set element type attribute to select promotion result for u64 and signed integers.
void set_u64_integer_promotion_target(const element::Type& u64_integer_promotion_target);

private:
bool m_promote_unsafe = false;
bool m_pytorch_scalar_promotion = false;
element::Type m_u64_integer_promotion_target = element::f32;
};
} // namespace v14
} // namespace op
} // namespace ov
1 change: 1 addition & 0 deletions src/core/include/openvino/op/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/convert_promote_types.hpp"
#include "openvino/op/convolution.hpp"
#include "openvino/op/cos.hpp"
#include "openvino/op/cosh.hpp"
Expand Down
3 changes: 2 additions & 1 deletion src/core/include/openvino/opsets/opset14_tbl.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporation
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand Down Expand Up @@ -219,3 +219,4 @@ _OPENVINO_OP_REG(ScaledDotProductAttention, ov::op::v13)
_OPENVINO_OP_REG(FakeConvert, ov::op::v13)

// New operations added in opset14
_OPENVINO_OP_REG(ConvertPromoteTypes, ov::op::v14)
Loading

0 comments on commit 733ed1b

Please sign in to comment.