Skip to content

Commit

Permalink
[TRANSFORMATIONS] TEST PR#22236 (openvinotoolkit#22280)
Browse files Browse the repository at this point in the history
Co-authored-by: Ivan Tikhonov <[email protected]>
  • Loading branch information
iefode and itikhono authored Jan 22, 2024
1 parent c538d03 commit a046589
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
7 changes: 2 additions & 5 deletions src/core/src/op/squared_difference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,7 @@ bool ov::op::v0::SquaredDifference::evaluate(TensorVector& outputs, const Tensor

bool ov::op::v0::SquaredDifference::has_evaluate() const {
OV_OP_SCOPE(v0_SquaredDifference_has_evaluate);
switch (get_input_element_type(0)) {
case element::f32:
if (get_input_element_type(0) == element::f32)
return true;
default:
return false;
}
return false;
}
18 changes: 18 additions & 0 deletions src/core/tests/pass/constant_folding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3930,3 +3930,21 @@ TEST(constant_folding, parameter_with_unspecified_type_from_host_tensor) {
auto model = std::make_shared<ov::Model>(ov::ResultVector{res}, ov::ParameterVector{param});
EXPECT_NO_THROW(run_constant_folding(model));
}

TEST(constant_folding, sq_diff) {
auto const_0 = std::make_shared<ov::op::v0::Constant>(element::f32, ov::Shape{1}, std::vector<float>{4});
auto const_1 = std::make_shared<ov::op::v0::Constant>(element::f32, ov::Shape{1}, std::vector<float>{2});
auto sq_diff = std::make_shared<ov::op::v0::SquaredDifference>(const_0, const_1);
auto res = std::make_shared<ov::op::v0::Result>(sq_diff);
auto model = std::make_shared<ov::Model>(ov::ResultVector{res}, ov::ParameterVector{});
auto ops = model->get_ops();
ASSERT_GT(ops.size(), 2);
EXPECT_NO_THROW(run_constant_folding(model));
ops = model->get_ordered_ops();
// constant + result
ASSERT_EQ(ops.size(), 2);
auto const_node = std::dynamic_pointer_cast<ov::op::v0::Constant>(ops.front());
ASSERT_NE(const_node, nullptr);
auto res_node = std::dynamic_pointer_cast<ov::op::v0::Result>(ops.back());
ASSERT_NE(res_node, nullptr);
}

0 comments on commit a046589

Please sign in to comment.