forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Opset13][FP8] Introduce FakeConvert op core (openvinotoolkit#20930)
* FakeConvert op init * Update dest types names * Update op hpp * Update opset ops number * Init type_prop tests * Add attributes tests * Add op check test * Update namespace in fc cpp * Update getters * Refactor static member * Make destination_type lower case * Update type in test * Move get_valid_types out of class * Update ops number in opset * Remove apply_scale attribute * Additional constructor to make `shift` input optional
- Loading branch information
Showing
8 changed files
with
226 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// Copyright (C) 2018-2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "openvino/op/op.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace v13 { | ||
/// \ingroup ov_ops_cpp_api | ||
class OPENVINO_API FakeConvert : public Op { | ||
public: | ||
OPENVINO_OP("FakeConvert", "opset13"); | ||
|
||
FakeConvert() = default; | ||
FakeConvert(const ov::Output<ov::Node>& arg, | ||
const ov::Output<ov::Node>& scale, | ||
std::string destination_type = "f8e4m3"); | ||
|
||
FakeConvert(const ov::Output<ov::Node>& arg, | ||
const ov::Output<ov::Node>& scale, | ||
const ov::Output<ov::Node>& shift, | ||
std::string destination_type = "f8e4m3"); | ||
|
||
void validate_and_infer_types() override; | ||
std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override; | ||
bool visit_attributes(ov::AttributeVisitor& visitor) override; | ||
bool has_evaluate() const override; | ||
|
||
const std::string& get_destination_type() const; | ||
|
||
private: | ||
void validate_type() const; | ||
|
||
std::string m_destination_type = "f8e4m3"; | ||
}; | ||
} // namespace v13 | ||
} // namespace op | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
// Copyright (C) 2018-2022 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/op/fake_convert.hpp" | ||
|
||
#include "itt.hpp" | ||
|
||
namespace ov { | ||
namespace op { | ||
namespace v13 { | ||
namespace fake_convert { | ||
static const std::vector<std::string>& get_valid_types() { | ||
static const std::vector<std::string> valid_types{"f8e4m3", "f8e5m2"}; | ||
return valid_types; | ||
} | ||
} // namespace fake_convert | ||
|
||
FakeConvert::FakeConvert(const ov::Output<ov::Node>& arg, | ||
const ov::Output<ov::Node>& scale, | ||
std::string destination_type) | ||
: Op({arg, scale}), | ||
m_destination_type(std::move(destination_type)) { | ||
constructor_validate_and_infer_types(); | ||
} | ||
|
||
FakeConvert::FakeConvert(const ov::Output<ov::Node>& arg, | ||
const ov::Output<ov::Node>& scale, | ||
const ov::Output<ov::Node>& shift, | ||
std::string destination_type) | ||
: Op({arg, scale, shift}), | ||
m_destination_type(std::move(destination_type)) { | ||
constructor_validate_and_infer_types(); | ||
} | ||
|
||
const std::string& FakeConvert::get_destination_type() const { | ||
return m_destination_type; | ||
} | ||
|
||
void FakeConvert::validate_and_infer_types() { | ||
OV_OP_SCOPE(v13_FakeConvert_validate_and_infer_types); | ||
validate_type(); | ||
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); | ||
} | ||
|
||
std::shared_ptr<ov::Node> FakeConvert::clone_with_new_inputs(const ov::OutputVector& new_args) const { | ||
OV_OP_SCOPE(v13_FakeConvert_clone_with_new_inputs); | ||
if (new_args.size() == 2) { | ||
return std::make_shared<FakeConvert>(new_args.at(0), new_args.at(1), m_destination_type); | ||
} else if (new_args.size() == 3) { | ||
return std::make_shared<FakeConvert>(new_args.at(0), new_args.at(1), new_args.at(2), m_destination_type); | ||
} else { | ||
OPENVINO_THROW("Incorrect number of FakeConvert new arguments."); | ||
} | ||
} | ||
|
||
bool FakeConvert::visit_attributes(ov::AttributeVisitor& visitor) { | ||
OV_OP_SCOPE(v13_FakeConvert_visit_attributes); | ||
visitor.on_attribute("destination_type", m_destination_type); | ||
return true; | ||
} | ||
|
||
void FakeConvert::validate_type() const { | ||
const auto& valid_types = fake_convert::get_valid_types(); | ||
OPENVINO_ASSERT(std::find(valid_types.begin(), valid_types.end(), m_destination_type) != valid_types.end(), | ||
"Bad format for f8 conversion type: " + m_destination_type); | ||
} | ||
|
||
bool FakeConvert::has_evaluate() const { | ||
return false; | ||
} | ||
|
||
} // namespace v13 | ||
} // namespace op | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// Copyright (C) 2018-2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/op/fake_convert.hpp" | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include "common_test_utils/type_prop.hpp" | ||
|
||
using namespace ov; | ||
using ov::op::v0::Parameter; | ||
|
||
TEST(type_prop, fake_convert_no_shift) { | ||
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{2, 3, 8, 6}); | ||
const auto scale = std::make_shared<Parameter>(element::f32, PartialShape{}); | ||
|
||
const auto op = std::make_shared<op::v13::FakeConvert>(data, scale); | ||
EXPECT_EQ(op->get_output_element_type(0), element::f32); | ||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 8, 6})); | ||
} | ||
|
||
TEST(type_prop, fake_convert_basic_f32) { | ||
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{2, 3, 8, 6}); | ||
const auto scale = std::make_shared<Parameter>(element::f32, PartialShape{}); | ||
const auto shift = std::make_shared<Parameter>(element::f32, PartialShape{}); | ||
|
||
const auto op = std::make_shared<op::v13::FakeConvert>(data, scale, shift); | ||
EXPECT_EQ(op->get_output_element_type(0), element::f32); | ||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 8, 6})); | ||
} | ||
|
||
TEST(type_prop, fake_convert_basic_f16) { | ||
const auto data = std::make_shared<Parameter>(element::f16, PartialShape{2, 3, 8, 6}); | ||
const auto scale = std::make_shared<Parameter>(element::f16, PartialShape{}); | ||
const auto shift = std::make_shared<Parameter>(element::f16, PartialShape{}); | ||
|
||
const auto op = std::make_shared<op::v13::FakeConvert>(data, scale, shift); | ||
EXPECT_EQ(op->get_output_element_type(0), element::f16); | ||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 8, 6})); | ||
} | ||
|
||
TEST(type_prop, fake_convert_dynamic_shape) { | ||
const auto data = std::make_shared<Parameter>(element::f32, PartialShape::dynamic()); | ||
const auto scale = std::make_shared<Parameter>(element::f32, PartialShape{}); | ||
const auto shift = std::make_shared<Parameter>(element::f32, PartialShape{}); | ||
|
||
const auto op = std::make_shared<op::v13::FakeConvert>(data, scale, shift); | ||
EXPECT_EQ(op->get_output_element_type(0), element::f32); | ||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape::dynamic())); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
// Copyright (C) 2018-2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/op/fake_convert.hpp" | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include "visitors/visitors.hpp" | ||
|
||
using ov::Shape; | ||
using ov::op::v0::Parameter; | ||
using ov::test::NodeBuilder; | ||
|
||
TEST(attributes, fake_convert_v13_attributes_default) { | ||
using ov::op::v13::FakeConvert; | ||
NodeBuilder::get_ops().register_factory<FakeConvert>(); | ||
const auto data = std::make_shared<Parameter>(ov::element::f32, ov::PartialShape{2, 3, 8, 6}); | ||
const auto scale = std::make_shared<Parameter>(ov::element::f32, ov::PartialShape{}); | ||
const auto shift = std::make_shared<Parameter>(ov::element::f32, ov::PartialShape{}); | ||
|
||
const auto op = std::make_shared<FakeConvert>(data, scale, shift); | ||
|
||
NodeBuilder builder(op, {data, scale, shift}); | ||
auto g_op = ov::as_type_ptr<FakeConvert>(builder.create()); | ||
|
||
EXPECT_EQ(g_op->get_destination_type(), op->get_destination_type()); | ||
EXPECT_EQ(g_op->get_output_element_type(0), op->get_output_element_type(0)); | ||
EXPECT_EQ(g_op->get_output_partial_shape(0), op->get_output_partial_shape(0)); | ||
} | ||
|
||
TEST(attributes, fake_convert_v13_attributes_custom) { | ||
using ov::op::v13::FakeConvert; | ||
NodeBuilder::get_ops().register_factory<FakeConvert>(); | ||
const auto data = std::make_shared<Parameter>(ov::element::f32, ov::PartialShape{2, 3, 8, 6}); | ||
const auto scale = std::make_shared<Parameter>(ov::element::f32, ov::PartialShape{}); | ||
const auto shift = std::make_shared<Parameter>(ov::element::f32, ov::PartialShape{}); | ||
|
||
const auto op = std::make_shared<FakeConvert>(data, scale, shift, "f8e5m2"); | ||
|
||
NodeBuilder builder(op, {data, scale, shift}); | ||
auto g_op = ov::as_type_ptr<FakeConvert>(builder.create()); | ||
|
||
EXPECT_EQ(g_op->get_destination_type(), op->get_destination_type()); | ||
EXPECT_EQ(g_op->get_output_element_type(0), op->get_output_element_type(0)); | ||
EXPECT_EQ(g_op->get_output_partial_shape(0), op->get_output_partial_shape(0)); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters