Skip to content

Commit

Permalink
[Transformations] Add Squeeze-15 downgrade transformation (#27286)
Browse files Browse the repository at this point in the history
### Details:
- *Add Squeeze-15 downgrade transformation to Squeeze-0 for compatible
attribute*
 - *...*

### Tickets:
 - *CVS-154027*

### PR requires
[PR-26995](#26995) to be
merged

---------

Co-authored-by: Michal Lukaszewski <[email protected]>
  • Loading branch information
mmikolajcz and mlukasze authored Oct 31, 2024
1 parent 44b86a8 commit 86083e0
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {
/**
* @ingroup ov_transformation_common_api
* @brief Converts Squeeze v15 to Squeeze v0.
*/
class TRANSFORMATIONS_API ConvertSqueeze15ToSqueeze0 : public MatcherPass {
public:
OPENVINO_RTTI("ConvertSqueeze15ToSqueeze0", "0");
ConvertSqueeze15ToSqueeze0();
};

} // namespace pass
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
#include "transformations/op_conversions/convert_softmax_downgrade.hpp"
#include "transformations/op_conversions/convert_softmax_upgrade.hpp"
#include "transformations/op_conversions/convert_space_to_depth.hpp"
#include "transformations/op_conversions/convert_squeeze15_downgrade.hpp"
#include "transformations/op_conversions/convert_subtract.hpp"
#include "transformations/op_conversions/convert_topk11_downgrade.hpp"
#include "transformations/op_conversions/convert_xor_to_logical_xor.hpp"
Expand Down Expand Up @@ -235,6 +236,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
REGISTER_PASS(manager, ConvertEmbeddingBagPacked15ToEmbeddingBagPackedSum3)
REGISTER_PASS(manager, ConvertScatterNDUpdate15ToScatterNDUpdate3)
REGISTER_PASS(manager, ConvertSliceScatter)
REGISTER_PASS(manager, ConvertSqueeze15ToSqueeze0)

auto fq_fusions = manager.register_pass<GraphRewrite>();
ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/op_conversions/convert_squeeze15_downgrade.hpp"

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

ov::pass::ConvertSqueeze15ToSqueeze0::ConvertSqueeze15ToSqueeze0() {
MATCHER_SCOPE(ConvertSqueeze15ToSqueeze0);

const auto& squeeze_v15_pattern = pattern::wrap_type<ov::op::v15::Squeeze>();

const matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) {
const auto& squeeze_v15 = ov::as_type_ptr<ov::op::v15::Squeeze>(m.get_match_root());
if (!squeeze_v15 || transformation_callback(squeeze_v15)) {
return false;
}
std::shared_ptr<op::v0::Squeeze> squeeze_v0;
if (squeeze_v15->get_input_size() == 1) {
squeeze_v0 = std::make_shared<op::v0::Squeeze>(squeeze_v15->input_value(0));
} else if (squeeze_v15->get_input_size() == 2 && !squeeze_v15->get_allow_axis_skip()) {
squeeze_v0 = std::make_shared<op::v0::Squeeze>(squeeze_v15->input_value(0), squeeze_v15->input_value(1));
} else {
return false;
}
squeeze_v0->set_friendly_name(squeeze_v15->get_friendly_name());
copy_runtime_info(squeeze_v15, squeeze_v0);
replace_node(squeeze_v15, squeeze_v0);

return true;
};

auto m = std::make_shared<pattern::Matcher>(squeeze_v15_pattern, matcher_name);
register_matcher(m, callback);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/op_conversions/convert_squeeze15_downgrade.hpp"

#include <gtest/gtest.h>

#include <memory>

#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/opsets/opset1.hpp"
#include "openvino/opsets/opset15.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/utils/utils.hpp"
using namespace ov;
using namespace testing;

namespace {

enum class IndicesMode { NONE, CONST, PARAM };

std::shared_ptr<ov::Model> create_v15_model(const IndicesMode indices_mode,
const std::vector<int> indices_const_val,
const bool allow_axis_skip) {
const PartialShape data_shape{-1, {2, 5}, 1, {1, 5}, 4};
const auto& data = std::make_shared<ov::opset15::Parameter>(ov::element::f32, data_shape);
ov::ParameterVector params = {data};
std::shared_ptr<op::v15::Squeeze> squeeze;
if (indices_mode == IndicesMode::NONE) {
squeeze = std::make_shared<ov::opset15::Squeeze>(data, allow_axis_skip);
} else if (indices_mode == IndicesMode::PARAM) {
const auto& indices =
std::make_shared<ov::opset15::Parameter>(ov::element::i32, PartialShape({data_shape.rank()}));
params.push_back(indices);
squeeze = std::make_shared<ov::opset15::Squeeze>(data, indices, allow_axis_skip);
} else if (indices_mode == IndicesMode::CONST) {
const auto& indices =
ov::opset15::Constant::create(ov::element::i32, Shape({indices_const_val.size()}), indices_const_val);
squeeze = std::make_shared<ov::opset15::Squeeze>(data, indices, allow_axis_skip);
}
squeeze->set_friendly_name("squeeze15");
return std::make_shared<ov::Model>(squeeze->outputs(), params);
}

std::shared_ptr<ov::Model> create_v1_model(const IndicesMode indices_mode, const std::vector<int> indices_const_val) {
const PartialShape data_shape{-1, {2, 5}, 1, {1, 5}, 4};
const auto& data = std::make_shared<ov::opset1::Parameter>(ov::element::f32, data_shape);
ov::ParameterVector params = {data};
std::shared_ptr<op::v0::Squeeze> squeeze;
if (indices_mode == IndicesMode::NONE) {
squeeze = std::make_shared<ov::opset1::Squeeze>(data);
} else if (indices_mode == IndicesMode::PARAM) {
const auto& indices =
std::make_shared<ov::opset1::Parameter>(ov::element::i32, PartialShape({data_shape.rank()}));
params.push_back(indices);
squeeze = std::make_shared<ov::opset1::Squeeze>(data, indices);
} else if (indices_mode == IndicesMode::CONST) {
const auto& indices =
ov::opset1::Constant::create(ov::element::i32, Shape({indices_const_val.size()}), indices_const_val);
squeeze = std::make_shared<ov::opset1::Squeeze>(data, indices);
}
squeeze->set_friendly_name("squeeze15");
return std::make_shared<ov::Model>(squeeze->outputs(), params);
}

} // namespace

TEST_F(TransformationTestsF, ConvertSqueeze15ToSqueeze1_no_indices_no_skip) {
manager.register_pass<ov::pass::ConvertSqueeze15ToSqueeze0>();
model = create_v15_model(IndicesMode::NONE, {}, false);
model_ref = create_v1_model(IndicesMode::NONE, {});
EXPECT_EQ(model->output(0).get_partial_shape(), model_ref->output(0).get_partial_shape());
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::NAMES);
}

TEST_F(TransformationTestsF, ConvertSqueeze15ToSqueeze1_no_indices_skip) {
manager.register_pass<ov::pass::ConvertSqueeze15ToSqueeze0>();
model = create_v15_model(IndicesMode::NONE, {}, true);
model_ref = create_v1_model(IndicesMode::NONE, {});
EXPECT_EQ(model->output(0).get_partial_shape(), model_ref->output(0).get_partial_shape());
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::NAMES);
}

TEST_F(TransformationTestsF, ConvertSqueeze15ToSqueeze1_const_indices_no_skip) {
manager.register_pass<ov::pass::ConvertSqueeze15ToSqueeze0>();
model = create_v15_model(IndicesMode::CONST, {0, -4, 3}, false);
model_ref = create_v1_model(IndicesMode::CONST, {0, -4, 3});
EXPECT_EQ(model->output(0).get_partial_shape(), model_ref->output(0).get_partial_shape());
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::NAMES);
}

TEST_F(TransformationTestsF, ConvertSqueeze15ToSqueeze1_dynamic_indices_no_skip) {
manager.register_pass<ov::pass::ConvertSqueeze15ToSqueeze0>();
model = create_v15_model(IndicesMode::PARAM, {}, false);
model_ref = create_v1_model(IndicesMode::PARAM, {});
EXPECT_EQ(model->output(0).get_partial_shape(), model_ref->output(0).get_partial_shape());
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
comparator.enable(FunctionsComparator::CmpValues::NAMES);
}

TEST_F(TransformationTestsF, ConvertSqueeze15ToSqueeze1_unsupported_skip) {
manager.register_pass<ov::pass::ConvertSqueeze15ToSqueeze0>();
model = create_v15_model(IndicesMode::PARAM, {}, true);
}

0 comments on commit 86083e0

Please sign in to comment.