Skip to content

Commit

Permalink
[PT FE] Support any float type for batch norm (#23750)
Browse files Browse the repository at this point in the history
### Details:
 - *Support any float type for batch norm*
- *Tests for fp16 sporadically fail due to accuracy issues, and fp64 is not
supported by torch, so the tests will not be updated this time.*

### Tickets:
 - *#23538*
  • Loading branch information
mvafin authored Apr 2, 2024
1 parent e2c6ae9 commit a0a80e0
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/frontends/pytorch/src/op/batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ Output<Node> broadcast_const_to_channel_dim(const NodeContext& context,
auto one_i = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
auto channel_dim = context.mark_node(std::make_shared<v8::Gather>(input_shape, one_i, zero_i));
auto channel_dim_exp = context.mark_node(std::make_shared<v0::Unsqueeze>(channel_dim, zero_i));
return context.mark_node(std::make_shared<v3::Broadcast>(value, channel_dim_exp));
auto value_ = context.mark_node(std::make_shared<v1::ConvertLike>(value, input));
return context.mark_node(std::make_shared<v3::Broadcast>(value_, channel_dim_exp));
}

OutputVector make_batch_norm(const NodeContext& context,
Expand All @@ -53,10 +54,14 @@ OutputVector make_batch_norm(const NodeContext& context,
if (!w.get_node_shared_ptr()) {
auto one_f = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
w = broadcast_const_to_channel_dim(context, input, one_f);
} else {
w = context.mark_node(std::make_shared<v1::ConvertLike>(w, input));
}
if (!b.get_node_shared_ptr()) {
auto zero_f = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
b = broadcast_const_to_channel_dim(context, input, zero_f);
} else {
b = context.mark_node(std::make_shared<v1::ConvertLike>(b, input));
}
auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
Expand All @@ -67,12 +72,16 @@ OutputVector make_batch_norm(const NodeContext& context,
auto axes = context.mark_node(std::make_shared<v0::Concat>(OutputVector{zero_1d, after_channel_dims}, 0));
if (!mean.get_node_shared_ptr()) {
mean = context.mark_node(std::make_shared<v1::ReduceMean>(input, axes, false));
} else {
mean = context.mark_node(std::make_shared<v1::ConvertLike>(mean, input));
}
if (!var.get_node_shared_ptr()) {
auto current_mean = context.mark_node(std::make_shared<v1::ReduceMean>(input, axes, true));
auto sub_v = context.mark_node(std::make_shared<v1::Subtract>(input, current_mean));
auto sqr_sub = context.mark_node(std::make_shared<v1::Multiply>(sub_v, sub_v));
var = context.mark_node(std::make_shared<v1::ReduceMean>(sqr_sub, axes, false));
} else {
var = context.mark_node(std::make_shared<v1::ConvertLike>(var, input));
}
return {context.mark_node(std::make_shared<v5::BatchNormInference>(input, w, b, mean, var, epsilon))};
}
Expand Down

0 comments on commit a0a80e0

Please sign in to comment.