Skip to content

Commit

Permalink
upcast randn to fp32
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir committed Nov 30, 2023
1 parent 51b5bc5 commit ed857e9
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
11 changes: 8 additions & 3 deletions src/frontends/onnx/frontend/src/utils/random_normal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "default_opset.hpp"
#include "ngraph/opsets/opset8.hpp"
#include "transformations/rt_info/disable_fp16_compression.hpp"

namespace ngraph {
namespace onnx_import {
Expand All @@ -27,7 +28,7 @@ OutputVector make_random_normal(const Output<ngraph::Node>& shape,
const uint64_t seed_1 = op_seed;
const uint64_t seed_2 = (op_seed == 0 ? op_seed : op_seed + 10000);

const auto min_val = default_opset::Constant::create(target_type, Shape{1}, {0});
const auto min_val = default_opset::Constant::create(target_type, Shape{1}, {std::numeric_limits<float>::min()});
const auto max_val = default_opset::Constant::create(target_type, Shape{1}, {1});

const auto uniform_1 =
Expand All @@ -45,15 +46,19 @@ OutputVector make_random_normal(const Output<ngraph::Node>& shape,
const auto multiply_minus_two_log = std::make_shared<default_opset::Multiply>(log, minus_two);
const auto sqrt = std::make_shared<default_opset::Sqrt>(multiply_minus_two_log);

const auto multiply_two_pi = std::make_shared<default_opset::Multiply>(uniform_2, pi);
const auto multiply_two_pi_uniform_2 = std::make_shared<default_opset::Multiply>(multiply_two_pi, uniform_2);
const auto multiply_pi_uniform2 = std::make_shared<default_opset::Multiply>(uniform_2, pi);
const auto multiply_two_pi_uniform_2 = std::make_shared<default_opset::Multiply>(multiply_pi_uniform2, two);
auto const cos = std::make_shared<default_opset::Cos>(multiply_two_pi_uniform_2);

auto const scale_const = default_opset::Constant::create(target_type, Shape{1}, {scale});
auto const mean_const = default_opset::Constant::create(target_type, Shape{1}, {mean});
auto const product =
std::make_shared<default_opset::Multiply>(scale_const, std::make_shared<default_opset::Multiply>(sqrt, cos));
auto const sum = std::make_shared<default_opset::Add>(product, mean_const);

Check warning on line 58 in src/frontends/onnx/frontend/src/utils/random_normal.cpp

View workflow job for this annotation

GitHub Actions / clang-format

[reviewdog-suggester] reported by reviewdog 🐶 Raw Output: src/frontends/onnx/frontend/src/utils/random_normal.cpp:58:- src/frontends/onnx/frontend/src/utils/random_normal.cpp:58:+
// if fp16 compression is not disabled, float32_min is flushed to zero in fp16, and log(0) gives -inf
disable_fp16_compression(uniform_1);
disable_fp16_compression(log);

return {sum};
}
Expand Down
12 changes: 8 additions & 4 deletions src/frontends/pytorch/src/op/rand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "openvino/op/shape_of.hpp"
#include "openvino/op/sqrt.hpp"
#include "pt_framework_node.hpp"
#include "transformations/rt_info/disable_fp16_compression.hpp"
#include "utils.hpp"

namespace ov {
Expand All @@ -39,7 +40,7 @@ OutputVector make_random_normal(const NodeContext& context,
const uint64_t seed_1 = distrib(gen);
const uint64_t seed_2 = distrib(gen);

auto min_val = context.mark_node(v0::Constant::create(target_type, Shape{1}, {0}));
auto min_val = context.mark_node(v0::Constant::create(target_type, Shape{1}, {std::numeric_limits<float>::min()}));
auto max_val = context.mark_node(v0::Constant::create(target_type, Shape{1}, {1}));

auto uniform_1 = context.mark_node(
Expand All @@ -57,14 +58,17 @@ OutputVector make_random_normal(const NodeContext& context,
auto multiply_minus_two_log = context.mark_node(std::make_shared<v1::Multiply>(log, minus_two));
auto sqrt = context.mark_node(std::make_shared<v0::Sqrt>(multiply_minus_two_log));

auto multiply_two_pi = context.mark_node(std::make_shared<v1::Multiply>(uniform_2, pi));
auto multiply_two_pi_uniform_2 = context.mark_node(std::make_shared<v1::Multiply>(multiply_two_pi, uniform_2));
auto multiply_pi_uniform2 = context.mark_node(std::make_shared<v1::Multiply>(uniform_2, pi));
auto multiply_two_pi_uniform_2 = context.mark_node(std::make_shared<v1::Multiply>(multiply_pi_uniform2, two));
auto cos = context.mark_node(std::make_shared<v0::Cos>(multiply_two_pi_uniform_2));

auto sqrt_x_cos = context.mark_node(std::make_shared<v1::Multiply>(sqrt, cos));
auto product = context.mark_node(std::make_shared<v1::Multiply>(scale_const, sqrt_x_cos));
auto sum = context.mark_node(std::make_shared<v1::Add>(product, mean_const));


Check warning on line 68 in src/frontends/pytorch/src/op/rand.cpp

View workflow job for this annotation

GitHub Actions / clang-format

[reviewdog-suggester] reported by reviewdog 🐶 Raw Output: src/frontends/pytorch/src/op/rand.cpp:68:- src/frontends/pytorch/src/op/rand.cpp:68:+
// if fp16 compression is not disabled, float32_min is flushed to zero in fp16, and log(0) gives -inf
disable_fp16_compression(uniform_1);
disable_fp16_compression(log);
return {sum};
}
}; // namespace
Expand Down

0 comments on commit ed857e9

Please sign in to comment.