Skip to content

Commit

Permalink
Support dyn shapes in BatchNormDecomposition transformation (openvino…
Browse files Browse the repository at this point in the history
…toolkit#23290)

### Details:
Support dynamic shapes in the BatchNormDecomposition transformation

### Tickets:
 - *CVS-133609*
  • Loading branch information
itikhono authored Mar 6, 2024
1 parent 2c0efbb commit aad89fb
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@ using namespace ov;

ov::pass::BatchNormDecomposition::BatchNormDecomposition() {
MATCHER_SCOPE(BatchNormDecomposition);
auto bn_1 = pattern::wrap_type<ov::op::v0::BatchNormInference>({pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
auto bn_1 = pattern::wrap_type<ov::op::v0::BatchNormInference>({pattern::any_input(),
pattern::any_input(),
pattern::any_input(pattern::has_static_rank()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape())});
pattern::any_input(),
pattern::any_input()});
auto bn_5 = pattern::wrap_type<ov::op::v5::BatchNormInference>({pattern::any_input(pattern::has_static_rank()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape())});
pattern::any_input(),
pattern::any_input(),
pattern::any_input(),
pattern::any_input()});
auto bn = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{bn_1, bn_5});

matcher_pass_callback callback = [this](ov::pass::pattern::Matcher& m) {
Expand Down Expand Up @@ -83,9 +83,8 @@ ov::pass::BatchNormDecomposition::BatchNormDecomposition() {
std::make_shared<ov::op::v1::Reshape>(gamma_div_scale, new_shape, true);
std::shared_ptr<Node> beta_aligned = std::make_shared<ov::op::v1::Reshape>(m_beta, new_shape, true);
std::shared_ptr<Node> mean_aligned = std::make_shared<ov::op::v1::Reshape>(m_mean, new_shape, true);
std::shared_ptr<Node> mean_negative = std::make_shared<ov::op::v1::Multiply>(
mean_aligned,
ov::op::v0::Constant::create(mean_aligned->get_output_element_type(0), Shape{}, {-1}));
auto mul_const = ov::op::v0::Constant::create(mean_aligned->get_output_element_type(0), Shape{}, {-1});
std::shared_ptr<Node> mean_negative = std::make_shared<ov::op::v1::Multiply>(mean_aligned, mul_const);

if (auto constant = ov::util::get_constant_from_source(beta_aligned))
beta_aligned = constant;
Expand All @@ -103,9 +102,23 @@ ov::pass::BatchNormDecomposition::BatchNormDecomposition() {

add->set_friendly_name(m_bn->get_friendly_name());

copy_runtime_info(
m_bn,
{scale_add, scale, gamma_div_scale, gamma_div_scale_aligned, beta_aligned, input_sub_mean, mul, add});
copy_runtime_info(m_bn,
{scale_add,
scale,
gamma_div_scale,
gamma_div_scale_aligned,
beta_aligned,
input_sub_mean,
mul,
add,
mean_negative,
mean_aligned,
new_shape,
tail_shape,
tail_shape_rank,
one,
mul_const,
C_dim});

replace_node(m_bn, add);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,45 @@
using namespace ov;
using namespace testing;

// Builds the expected (reference) graph produced by BatchNormDecomposition for
// an input with dynamic dimensions: ((data - mean) * gamma / sqrt(var + eps)) + beta.
// The 1-D per-channel parameters are reshaped at runtime to [1, C, 1, 1, ...],
// where the tail of ones is derived from the static rank of `input_shape`.
std::shared_ptr<ov::Model> get_ref_model_with_dyn_shapes(ov::element::Type precision, const PartialShape& input_shape) {
    const auto data = std::make_shared<opset1::Parameter>(precision, input_shape);
    const auto gamma = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
    const auto beta = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
    const auto mean = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
    const auto variance = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});

    // Denominator: sqrt(variance + eps), then fold gamma into it: gamma / sqrt(variance + eps).
    const auto eps_const = ov::op::v0::Constant::create(precision, Shape{}, {0.001});
    const auto var_plus_eps = std::make_shared<ov::op::v1::Add>(variance, eps_const);
    const auto std_dev = std::make_shared<ov::op::v0::Sqrt>(var_plus_eps);
    const auto gamma_over_std = std::make_shared<ov::op::v1::Divide>(gamma, std_dev);

    // Assemble the broadcastable shape [1, C, 1, 1, ...]; C is taken from gamma's ShapeOf,
    // and the number of trailing ones is rank(data) - 2 (batch and channel excluded).
    const int64_t tail_dims = data->get_partial_shape().rank().get_length() - 2;
    const auto one = ov::op::v0::Constant::create(element::i64, Shape{1}, {1});
    const auto tail_shape_rank = ov::op::v0::Constant::create(element::i64, Shape{1}, {tail_dims});
    const auto tail_shape = std::make_shared<ov::op::v3::Broadcast>(one, tail_shape_rank);
    const auto channel_dim = std::make_shared<ov::op::v3::ShapeOf>(gamma);
    const auto aligned_shape = std::make_shared<ov::op::v0::Concat>(OutputVector{one, channel_dim, tail_shape}, 0);

    // Align each per-channel tensor to the computed shape.
    std::shared_ptr<Node> gamma_div_scale_aligned =
        std::make_shared<ov::op::v1::Reshape>(gamma_over_std, aligned_shape, true);
    std::shared_ptr<Node> beta_aligned = std::make_shared<ov::op::v1::Reshape>(beta, aligned_shape, true);
    std::shared_ptr<Node> mean_aligned = std::make_shared<ov::op::v1::Reshape>(mean, aligned_shape, true);
    const auto minus_one = ov::op::v0::Constant::create(mean_aligned->get_output_element_type(0), Shape{}, {-1});
    std::shared_ptr<Node> mean_negative = std::make_shared<ov::op::v1::Multiply>(mean_aligned, minus_one);

    // result = (data + mean * -1) * (gamma / sqrt(variance + eps)) + beta
    const auto centered = std::make_shared<ov::op::v1::Add>(data, mean_negative);
    const auto scaled = std::make_shared<ov::op::v1::Multiply>(centered, gamma_div_scale_aligned);
    const auto result = std::make_shared<ov::op::v1::Add>(scaled, beta_aligned);

    return std::make_shared<ov::Model>(NodeVector{result}, ParameterVector{data, gamma, beta, mean, variance});
}

TEST_F(TransformationTestsF, BatchNormDecompositionStaticRankOpset1) {
const PartialShape input_shape{-1, -1, -1, -1};
const auto precision = element::f32;
Expand Down Expand Up @@ -74,6 +113,42 @@ TEST_F(TransformationTestsF, BatchNormDecompositionStaticRankOpset5) {
}
}

// BatchNormDecomposition must fire on opset1 BatchNormInference even when
// every dimension of the data input is dynamic (only the rank is static).
TEST_F(TransformationTestsF, BatchNormDecompositionDynamicShapesOpset1) {
    const auto precision = element::f32;
    const PartialShape input_shape{-1, -1, -1, -1};
    {
        const auto data = std::make_shared<opset1::Parameter>(precision, input_shape);
        const auto gamma = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
        const auto beta = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
        const auto mean = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
        const auto variance = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
        const auto bn = std::make_shared<opset1::BatchNormInference>(data, gamma, beta, mean, variance, 0.001);

        model = std::make_shared<ov::Model>(NodeVector{bn}, ParameterVector{data, gamma, beta, mean, variance});
        manager.register_pass<ov::pass::BatchNormDecomposition>();
        comparator.enable(FunctionsComparator::CONST_VALUES);
    }
    model_ref = get_ref_model_with_dyn_shapes(precision, input_shape);
}

// Same dynamic-shape coverage as the opset1 case above, but for the
// opset5 flavour of BatchNormInference (different input order in the op spec).
TEST_F(TransformationTestsF, BatchNormDecompositionDynamicShapesOpset5) {
    const auto precision = element::f32;
    const PartialShape input_shape{-1, -1, -1, -1};
    {
        const auto data = std::make_shared<opset1::Parameter>(precision, input_shape);
        const auto gamma = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
        const auto beta = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
        const auto mean = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
        const auto variance = std::make_shared<opset1::Parameter>(precision, PartialShape{-1});
        const auto bn = std::make_shared<opset5::BatchNormInference>(data, gamma, beta, mean, variance, 0.001);

        model = std::make_shared<ov::Model>(NodeVector{bn}, ParameterVector{data, gamma, beta, mean, variance});
        manager.register_pass<ov::pass::BatchNormDecomposition>();
        comparator.enable(FunctionsComparator::CONST_VALUES);
    }
    model_ref = get_ref_model_with_dyn_shapes(precision, input_shape);
}

TEST_F(TransformationTestsF, BatchNormDecompositionDynamicRank) {
{
auto input = std::make_shared<opset1::Parameter>(element::f32, PartialShape::dynamic());
Expand Down

0 comments on commit aad89fb

Please sign in to comment.