From a0a80e010318dae4e4453ba2fbdb71280faec81e Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Tue, 2 Apr 2024 10:04:56 +0200
Subject: [PATCH] [PT FE] Support any float type for batch norm (#23750)

### Details:
 - *Support any float type for batch norm.*
 - *Tests for fp16 sporadically fail on accuracy, and fp64 is not supported by torch, so the tests are not updated this time.*

### Tickets:
 - *#23538*

---
 src/frontends/pytorch/src/op/batch_norm.cpp | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/frontends/pytorch/src/op/batch_norm.cpp b/src/frontends/pytorch/src/op/batch_norm.cpp
index 4964eb37bf05ea..7f31d0894e4af6 100644
--- a/src/frontends/pytorch/src/op/batch_norm.cpp
+++ b/src/frontends/pytorch/src/op/batch_norm.cpp
@@ -36,7 +36,8 @@ Output<Node> broadcast_const_to_channel_dim(const NodeContext& context,
     auto one_i = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
     auto channel_dim = context.mark_node(std::make_shared<v8::Gather>(input_shape, one_i, zero_i));
     auto channel_dim_exp = context.mark_node(std::make_shared<v0::Unsqueeze>(channel_dim, zero_i));
-    return context.mark_node(std::make_shared<v3::Broadcast>(value, channel_dim_exp));
+    auto value_ = context.mark_node(std::make_shared<v1::ConvertLike>(value, input));
+    return context.mark_node(std::make_shared<v3::Broadcast>(value_, channel_dim_exp));
 }
 
 OutputVector make_batch_norm(const NodeContext& context,
@@ -53,10 +54,14 @@ OutputVector make_batch_norm(const NodeContext& context,
     if (!w.get_node_shared_ptr()) {
         auto one_f = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
         w = broadcast_const_to_channel_dim(context, input, one_f);
+    } else {
+        w = context.mark_node(std::make_shared<v1::ConvertLike>(w, input));
     }
     if (!b.get_node_shared_ptr()) {
         auto zero_f = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
         b = broadcast_const_to_channel_dim(context, input, zero_f);
+    } else {
+        b = context.mark_node(std::make_shared<v1::ConvertLike>(b, input));
     }
     auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
     auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
@@ -67,12 +72,16 @@ OutputVector make_batch_norm(const NodeContext& context,
     auto axes = context.mark_node(std::make_shared<v0::Concat>(OutputVector{zero_1d, after_channel_dims}, 0));
     if (!mean.get_node_shared_ptr()) {
         mean = context.mark_node(std::make_shared<v1::ReduceMean>(input, axes, false));
+    } else {
+        mean = context.mark_node(std::make_shared<v1::ConvertLike>(mean, input));
     }
     if (!var.get_node_shared_ptr()) {
         auto current_mean = context.mark_node(std::make_shared<v1::ReduceMean>(input, axes, true));
         auto sub_v = context.mark_node(std::make_shared<v1::Subtract>(input, current_mean));
         auto sqr_sub = context.mark_node(std::make_shared<v1::Multiply>(sub_v, sub_v));
         var = context.mark_node(std::make_shared<v1::ReduceMean>(sqr_sub, axes, false));
+    } else {
+        var = context.mark_node(std::make_shared<v1::ConvertLike>(var, input));
     }
     return {context.mark_node(std::make_shared<v5::BatchNormInference>(input, w, b, mean, var, epsilon))};
 }
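For context, the whole fix reduces to one pattern: every f32 constant or user-supplied affine/statistics input is routed through `ConvertLike` so it follows the element type of `input` instead of being hard-coded to f32. A minimal standalone sketch of that pattern (the shapes, names, and the `Multiply` consumer are illustrative assumptions, not part of the patch):

```cpp
// Minimal sketch of the ConvertLike pattern applied by this patch.
// Shapes and names are hypothetical; only the re-typing idea is the point.
#include <openvino/core/model.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert_like.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/parameter.hpp>

int main() {
    // Hypothetical f64 input, e.g. from a double-precision torch model.
    auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f64,
                                                         ov::Shape{1, 3, 8, 8});
    // The frontend creates defaults (weight = 1, bias = 0) as f32 constants...
    auto one_f = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{}, {1.0f});
    // ...and ConvertLike re-types them to match the input's element type
    // (here f64), keeping the graph type-consistent for f16/bf16/f64 alike.
    auto one_like = std::make_shared<ov::op::v1::ConvertLike>(one_f, input);
    auto scaled = std::make_shared<ov::op::v1::Multiply>(input, one_like);
    auto model = std::make_shared<ov::Model>(ov::OutputVector{scaled->output(0)},
                                             ov::ParameterVector{input});
    // The output inherits the input's type rather than collapsing to f32.
    return model->get_output_element_type(0) == ov::element::f64 ? 0 : 1;
}
```

`ConvertLike` is preferred over a fixed `Convert` here because the target type is taken from a second input at runtime, so the same subgraph works for any float type without the frontend having to know it in advance.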