Skip to content

Commit

Permalink
PullReshapeThroughDequantization rollback
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Mar 22, 2023
1 parent 6a066c6 commit 0112013
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ ngraph::pass::low_precision::PullReshapeThroughDequantization::PullReshapeThroug
const auto& opsMap = m.get_pattern_value_map();
auto reshape = opsMap.at(reshapeWrapper).get_node_shared_ptr();

auto child = reshape->get_output_target_inputs(0).begin()->get_node();
if (ov::is_type<opset1::GroupConvolution>(child)) {
return false;
}

while (reshape != nullptr) {
const auto parent = reshape->get_input_node_shared_ptr(0);
if (ov::is_type<opset1::Multiply>(parent) || ov::is_type<opset1::Subtract>(parent)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ const std::vector<ngraph::Shape> inputShapes = {
};

const std::vector<std::pair<ngraph::Shape, ngraph::Shape>> dequantizationOnWeightElementwiseConstantShapes = {
{ngraph::Shape({1, 960}), ngraph::Shape({960, 1, 1, 1, 1})},
{ngraph::Shape({9, 960}), ngraph::Shape({960, 1, 1, 3, 3})}
{ngraph::Shape({1, 960}), ngraph::Shape({960, 1, 1, 1})},
{ngraph::Shape({9, 960}), ngraph::Shape({960, 1, 3, 3})}
};

const std::vector<ngraph::Shape> multiplyShapes = {ngraph::Shape({1, 1, 960, 1})};
Expand Down Expand Up @@ -230,7 +230,7 @@ const std::vector<PullReshapeThroughDequantizationTestValues> testValues = {
{ {127.f}, element::f32, {}, false, 1ul, element::u8, true },
{ {0.02f}, element::f32, {}, false }
},
{ std::vector<float>{ 2.f }, ngraph::element::i8, {960, 1, 1, 3, 3}},
{ std::vector<float>{ 2.f }, ngraph::element::i8, {960, 1, 3, 3}},
{
{ ngraph::element::f32, false },
{ {127.f}, element::f32, {/* from parameter */}, false },
Expand All @@ -239,7 +239,7 @@ const std::vector<PullReshapeThroughDequantizationTestValues> testValues = {
{},
{},
{},
{},
{{960, 1, 1, 3, 3}},
ngraph::element::f32,
{}
}
Expand Down Expand Up @@ -327,7 +327,7 @@ const std::vector<PullReshapeThroughDequantizationTestValues> testValues = {
{ {127.f}, element::f32, {}, false, 1ul, element::u8, true },
{ {0.02f}, element::f32, {}, false }
},
{ std::vector<float>{ 2.f }, ngraph::element::i8, {960, 1, 1, 3, 3}},
{ std::vector<float>{ 2.f }, ngraph::element::i8, {960, 1, 3, 3}},
{
{ ngraph::element::f32, false },
{ {127.f}, element::f32, {/* from parameter */}, false, 1ul, element::i8, true },
Expand All @@ -336,7 +336,7 @@ const std::vector<PullReshapeThroughDequantizationTestValues> testValues = {
{},
{},
{},
{},
{{960, 1, 1, 3, 3}},
ngraph::element::f32,
{}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ const std::vector<ngraph::Shape> inputShapes = {
};

const std::vector<std::pair<ngraph::Shape, ngraph::Shape>> dequantizationOnWeightElementwiseConstantShapes = {
{ngraph::Shape({}), ngraph::Shape({1, 1, 1, 1, 1})},
{ngraph::Shape({1}), ngraph::Shape({1, 1, 1, 1, 1})}};
{ngraph::Shape({}), ngraph::Shape({1, 1, 1, 1})},
{ngraph::Shape({1}), ngraph::Shape({1, 1, 1, 1})}};

const std::vector<PullTransposeThroughDequantizationTestValues> testValues = {
// Actual:
Expand Down Expand Up @@ -214,7 +214,7 @@ const std::vector<PullTransposeThroughDequantizationTestValues> testValues = {
{{127.f}, element::f32, {}, false, 1ul, element::u8, true},
{{0.02f}, element::f32, {}, false}
},
{std::vector<float>{2.f}, ngraph::element::i8, {960, 1, 1, 3, 3}},
{std::vector<float>{2.f}, ngraph::element::i8, {960, 1, 3, 3}},
{
{ngraph::element::f32, false},
{{127.f}, element::f32, {/* from parameter */}, false},
Expand All @@ -223,7 +223,7 @@ const std::vector<PullTransposeThroughDequantizationTestValues> testValues = {
{},
{},
{},
{},
{{960, 1, 1, 3, 3}},
ngraph::element::f32,
{}
}
Expand Down

0 comments on commit 0112013

Please sign in to comment.