Skip to content

Commit

Permalink
SnippetsFunctionBase usage
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Mar 21, 2023
1 parent a12dc7c commit cffba04
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 109 deletions.
45 changes: 3 additions & 42 deletions src/common/snippets/tests/src/pass/precision_propagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,45 +17,6 @@ namespace snippets {

namespace {

/**
* @class DummyAdd
* @brief DummyAdd operation has custom validate_and_infer_types method implementation.
*/
class DummyAdd : public ngraph::opset1::Add {
public:
OPENVINO_OP("DummyAdd", "test::snippets");

DummyAdd(const Output<Node>& arg0,
const Output<Node>& arg1,
const ngraph::op::AutoBroadcastSpec& auto_broadcast =
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))
: ngraph::opset1::Add(arg0, arg1, auto_broadcast) {
constructor_validate_and_infer_types();
}

DummyAdd(const ngraph::opset1::Add& add)
: Add(add.get_input_source_output(0), add.get_input_source_output(1), add.get_autob()) {
constructor_validate_and_infer_types();
}

DummyAdd() = default;

void validate_and_infer_types() override {
const auto input_type1 = get_input_element_type(0);
const auto input_type2 = get_input_element_type(1);

const element::Type output_type = (input_type1 == element::i8) || (input_type2 == element::i8) ?
element::i32 :
get_input_element_type(0);

set_output_type(0, output_type, get_input_partial_shape(0));
}

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override {
return std::make_shared<DummyAdd>(new_args.at(0), new_args.at(1), this->get_autob());
}
};

class DummyPrecisionPropagationTargetMachine : public DummyTargetMachine {
public:
DummyPrecisionPropagationTargetMachine(
Expand Down Expand Up @@ -114,7 +75,7 @@ TEST_P(PrecisionPropagationTest, CompareFunctions) {
const auto test_values = std::get<1>(param);

const auto input_shapes = std::vector<PartialShape>({ shapes.first, shapes.second });
EmpyInfrastructureSupportStubForPrecisionPropagationFunction<DummyAdd> infrastructureSupportStub(
PrecisionPropagationAddFunction function_stub(
input_shapes,
test_values.input_types[0],
test_values.input_types[1],
Expand All @@ -130,15 +91,15 @@ TEST_P(PrecisionPropagationTest, CompareFunctions) {
test_values.expected.convertion_before_op2_2,
test_values.expected.convertion_after_op2
});
function = infrastructureSupportStub.getOriginal();
function = function_stub.getOriginal();

const auto target_machine = std::make_shared<DummyPrecisionPropagationTargetMachine>(
test_values.actual.op1_supported_precisions,
test_values.actual.op2_supported_precisions);

manager.register_pass<ngraph::snippets::pass::PropagatePrecision>(target_machine);

function_ref = infrastructureSupportStub.getReference();
function_ref = function_stub.getReference();
}

namespace PrecisionPropagationTestInstantiation {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,94 @@ namespace ov {
namespace test {
namespace snippets {

class PrecisionPropagationFunction {
/**
* @class DummyAdd
* @brief DummyAdd operation has custom validate_and_infer_types method implementation.
*/
class DummyAdd : public ngraph::opset1::Add {
public:
template<typename T>
OPENVINO_OP("DummyAdd", "test::snippets");

DummyAdd(const Output<Node>& arg0,
const Output<Node>& arg1,
const ngraph::op::AutoBroadcastSpec& auto_broadcast =
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))
: ngraph::opset1::Add(arg0, arg1, auto_broadcast) {
constructor_validate_and_infer_types();
}

DummyAdd(const ngraph::opset1::Add& add)
: Add(add.get_input_source_output(0), add.get_input_source_output(1), add.get_autob()) {
constructor_validate_and_infer_types();
}

DummyAdd() = default;

void validate_and_infer_types() override {
const auto input_type1 = get_input_element_type(0);
const auto input_type2 = get_input_element_type(1);

const element::Type output_type = (input_type1 == element::i8) || (input_type2 == element::i8) ?
element::i32 :
get_input_element_type(0);

set_output_type(0, output_type, get_input_partial_shape(0));
}

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override {
return std::make_shared<DummyAdd>(new_args.at(0), new_args.at(1), this->get_autob());
}
};

class PrecisionPropagationAddFunctionParams {
public:
class Actual {
public:
std::pair<element::Type, element::Type> convertion_before_op1;
element::Type convertion_before_op2_1;
std::pair<element::Type, element::Type> convertion_before_op2_2;
};

class Expected {
public:
std::pair<element::Type, element::Type> convertion_before_op1;
element::Type convertion_before_op2_1;
std::pair<element::Type, element::Type> convertion_before_op2_2;
element::Type convertion_after_op2;
};
};

/**
* @class PrecisionPropagationAddFunction
* @brief PrecisionPropagationAddFunction instance returns reference and original functions.
*
* Input arguments are used to create function in getOriginal or getReference methods only.
* Dont use getLowered method, it is not implemented and throw std::runtime_error exception.
* Note, ov::element::Type_t precision base type input argument is not used.
*/
class PrecisionPropagationAddFunction : public SnippetsFunctionBase {
public:
explicit PrecisionPropagationAddFunction(
const std::vector<PartialShape> input_shapes,
const ngraph::element::Type precision1,
const ngraph::element::Type precision2,
const ngraph::element::Type constant_precision,
PrecisionPropagationAddFunctionParams::Actual actual,
PrecisionPropagationAddFunctionParams::Expected expected) :
SnippetsFunctionBase(input_shapes),
precision1(precision1),
precision2(precision2),
constant_precision(constant_precision),
actual(actual),
expected(expected) {
OPENVINO_ASSERT(input_shapes.size() == 2ull, "input_shapes size has to be equal to 2");
}

/*
* Don't call this method explicity. You should create the instance of PrecisionPropagationAddFunction before.
* After the method will be called implicitly in getOriginal or getReference methods.
* Note, please, getLowered method is not implemented.
*/
static std::shared_ptr<ngraph::Function> get(
const ngraph::element::Type precision1,
const ngraph::PartialShape& inputShape1,
Expand All @@ -26,10 +111,29 @@ class PrecisionPropagationFunction {
const element::Type convertion_before_op2_1 = element::undefined,
const std::pair<element::Type, element::Type>& convertion_before_op2_2 = std::pair<element::Type, element::Type>(),
const element::Type convertion_after_op2 = {}) {
const auto create_convert = [](std::shared_ptr<Node> parent, const element::Type convertion_type) -> std::shared_ptr<Node> {
return convertion_type == element::undefined
? std::dynamic_pointer_cast<Node>(parent)
: std::make_shared<ngraph::snippets::op::ConvertSaturation>(parent, convertion_type);
};

const auto make_branch = [&create_convert](
const ngraph::element::Type precision,
const ngraph::PartialShape& inputShape,
const size_t index,
const element::Type convertion_type) -> std::pair<std::shared_ptr<ngraph::opset1::Parameter>, std::shared_ptr<ov::Node>> {
const auto parameter = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
parameter->set_friendly_name("parameter" + std::to_string(index));

std::shared_ptr<Node> parent = create_convert(parameter, convertion_type);

return { parameter, parent };
};

const auto branch1 = make_branch(precision1, inputShape1, 1, convertion_before_op1.first);
const auto branch2 = make_branch(precision2, inputShape2, 2, convertion_before_op1.second);

std::shared_ptr<Node> parent = std::make_shared<T>(branch1.second, branch2.second);
std::shared_ptr<Node> parent = std::make_shared<DummyAdd>(branch1.second, branch2.second);
parent->set_friendly_name("add");

parent = create_convert(parent, convertion_before_op2_1);
Expand Down Expand Up @@ -57,73 +161,15 @@ class PrecisionPropagationFunction {
result_out_tensor.set_names({ "result_tensor" });
result->set_friendly_name("result");

const ngraph::ResultVector results{result};
const ngraph::ParameterVector parameters{branch1.first, branch2.first};
const ngraph::ResultVector results{ result };
const ngraph::ParameterVector parameters{ branch1.first, branch2.first };
const auto model = std::make_shared<ngraph::Function>(results, parameters, "SnippetsPrecisionPropagation");
return model;
}

private:
static std::shared_ptr<Node> create_convert(std::shared_ptr<Node> parent, const element::Type convertion_type) {
return convertion_type == element::undefined
? std::dynamic_pointer_cast<Node>(parent)
: std::make_shared<ngraph::snippets::op::ConvertSaturation>(parent, convertion_type);
}

static std::pair<std::shared_ptr<ngraph::opset1::Parameter>, std::shared_ptr<ov::Node>> make_branch(
const ngraph::element::Type precision,
const ngraph::PartialShape& inputShape,
const size_t index,
const element::Type convertion_type) {
const auto parameter = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
parameter->set_friendly_name("parameter" + std::to_string(index));

std::shared_ptr<Node> parent = create_convert(parameter, convertion_type);

return {parameter, parent};
}
};

class PrecisionPropagationFunctionParams {
public:
class Actual {
public:
std::pair<element::Type, element::Type> convertion_before_op1;
element::Type convertion_before_op2_1;
std::pair<element::Type, element::Type> convertion_before_op2_2;
};

class Expected {
public:
std::pair<element::Type, element::Type> convertion_before_op1;
element::Type convertion_before_op2_1;
std::pair<element::Type, element::Type> convertion_before_op2_2;
element::Type convertion_after_op2;
};
};

template<typename T>
class EmpyInfrastructureSupportStubForPrecisionPropagationFunction : public SnippetsFunctionBase {
public:
explicit EmpyInfrastructureSupportStubForPrecisionPropagationFunction(
const std::vector<PartialShape> input_shapes,
const ngraph::element::Type precision1,
const ngraph::element::Type precision2,
const ngraph::element::Type constant_precision,
PrecisionPropagationFunctionParams::Actual actual,
PrecisionPropagationFunctionParams::Expected expected) :
SnippetsFunctionBase(input_shapes),
precision1(precision1),
precision2(precision2),
constant_precision(constant_precision),
actual(actual),
expected(expected) {
OPENVINO_ASSERT(input_shapes.size() == 2ull, "input_shapes size has to be equal to 2");
}

protected:
std::shared_ptr<ov::Model> initOriginal() const override {
return PrecisionPropagationFunction::get<T>(
return get(
precision1,
input_shapes[0],
precision2,
Expand All @@ -135,7 +181,7 @@ class EmpyInfrastructureSupportStubForPrecisionPropagationFunction : public Snip
}

std::shared_ptr<ov::Model> initReference() const override {
return PrecisionPropagationFunction::get<T>(
return get(
precision1,
input_shapes[0],
precision2,
Expand All @@ -150,8 +196,8 @@ class EmpyInfrastructureSupportStubForPrecisionPropagationFunction : public Snip
const ngraph::element::Type precision1;
const ngraph::element::Type precision2;
const ngraph::element::Type constant_precision;
const PrecisionPropagationFunctionParams::Actual actual;
const PrecisionPropagationFunctionParams::Expected expected;
const PrecisionPropagationAddFunctionParams::Actual actual;
const PrecisionPropagationAddFunctionParams::Expected expected;
};

} // namespace snippets
Expand Down

0 comments on commit cffba04

Please sign in to comment.