added test case with bias
alvoron committed Nov 20, 2024
1 parent 2e9a2a0 commit 1c639ef
Showing 9 changed files with 158 additions and 2 deletions.
@@ -220,6 +220,128 @@ auto is_skipped_op(const std::shared_ptr<ov::Node>& op) -> bool {
ov::is_type<ov::op::v0::Parameter>(op) ||
ov::is_type<ov::op::v0::Result>(op);
}

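// Summary of the check below: the node qualifies as a scale-shift when it is a binary
// eltwise (Add, Multiply, Subtract or Divide) with exactly one non-constant data input,
// every constant input feeding only this node, and every constant broadcastable to the
// data either per-tensor or per-channel. Illustrative shapes (not taken from this patch):
// with data {1, 16, 64, 64} and channelAxis == 1, constants of shape {} or {1, 16, 1, 1}
// are accepted, while {1, 16, 64, 1} is rejected.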
bool canBePerformedAsScaleShift(const std::shared_ptr<const Node> &node, const int channelAxis) {
size_t fusingPort = 0;
size_t numNonConstInputs = 0;
ov::PartialShape dataShape;
for (size_t i = 0; i < node->get_input_size(); i++) {
const auto parent = node->get_input_node_shared_ptr(i);
if (!ov::is_type<ov::op::v0::Constant>(parent)) {
fusingPort = i;
dataShape = node->get_input_partial_shape(i);
// only one non-const parent is allowed
if (++numNonConstInputs != 1)
return false;
} else {
// every const parent must have exactly one child
const auto out = parent->outputs();
const bool has_only_child = (out.size() == 1) && (out[0].get_target_inputs().size() == 1);
if (!has_only_child)
return false;
}
}

const auto isBroadcastableToDataInput = [&]() {
for (size_t i = 0; i < node->get_input_size(); i++) {
if (i == fusingPort)
continue;
const ov::PartialShape weightShape = node->get_input_partial_shape(i);
if (!isPerTensorOrPerChannelBroadcastable(dataShape.get_max_shape(), weightShape.get_max_shape(), channelAxis, true))
return false;
}
return true;
};

// Prelu and MulAdd are still ignored
// isConvertablePowerStatic() is ignored
return (ov::is_type<ov::opset1::Add>(node) ||
ov::is_type<ov::opset1::Multiply>(node) ||
ov::is_type<ov::opset1::Subtract>(node) ||
ov::is_type<ov::opset1::Divide>(node)) &&
isBroadcastableToDataInput();
}

bool SupportsFusingWithConvolution_Simple(const std::shared_ptr<const Node> &node, const int channelAxis = DEFAULT_AXIS) {
return ov::is_type<ov::op::v0::Tanh>(node) ||
ov::is_type<ov::op::v0::Gelu>(node) ||
ov::is_type<ov::op::v7::Gelu>(node) ||
ov::is_type<ov::op::v0::Abs>(node) ||
ov::is_type<ov::op::v0::Sqrt>(node) ||
ov::is_type<ov::op::v0::FakeQuantize>(node) ||
canBePerformedAsScaleShift(node, channelAxis);
}

bool isSuitableChildForFusingMatMul(const std::shared_ptr<const Node> &node, const bool canMatMulBeExecutedInI8,
NodeFusingType &updatedChainType, int& fusingAxis) {
// Firstly check for Bias and DQScales fusion
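// Illustrative example (shapes assumed, not from this patch): for a MatMul output of
// shape [B, M, 256] and fusingAxis == 2, an Add bias of shape [256] or [1, 1, 256] is
// accepted: after right-aligned normalization every dimension except the one at
// fusingAxis must be 1, the total element count must equal that dimension, the bias must
// lie on a constant path with this node as its only consumer, and the MatMul itself must
// have no other consumers.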
const bool is_bias = ov::is_type<ov::opset1::Add>(node);
const bool is_dq_scales = ov::is_type<ov::opset1::Multiply>(node) && canMatMulBeExecutedInI8;
if (is_bias || is_dq_scales) {
for (const auto &in : node->inputs()) {
const auto& parent_out = in.get_source_output();
const auto& parent = parent_out.get_node_shared_ptr();
const auto& parent_pshape = parent_out.get_partial_shape();
if (ov::is_type<ov::op::v0::MatMul>(parent) && parent_pshape.rank().is_static()) {
if (parent->get_output_target_inputs(0).size() > 1)
break;
const auto bias_port = 1 - in.get_index();
const auto bias_out = node->input_value(bias_port);
if ((bias_out.get_target_inputs().size() > 1) || !ov::op::util::is_on_constant_path(bias_out))
break;
const auto& bias_pshape = bias_out.get_partial_shape();
if (bias_pshape.is_dynamic())
break;
auto getNormalizedPShape = [](const ov::PartialShape &dims, size_t ndims) -> ov::PartialShape {
if (dims.size() >= ndims)
return dims;
ov::PartialShape pshape(std::vector<size_t>(ndims, 1));
std::copy(dims.rbegin(), dims.rend(), pshape.rbegin());
return pshape;
};
const auto bias_pshape_norm = getNormalizedPShape(bias_pshape, parent_pshape.size());
if (fusingAxis >= static_cast<int>(bias_pshape_norm.size()) || fusingAxis >= static_cast<int>(parent_pshape.size()) ||
bias_pshape_norm.size() != parent_pshape.size() || bias_pshape_norm.size() < 2)
break;
if (((bias_pshape_norm[fusingAxis] == parent_pshape[fusingAxis]) || (is_dq_scales && bias_pshape_norm[fusingAxis] == 1)) &&
(bias_pshape_norm[fusingAxis] == static_cast<int64_t>(shape_size(bias_pshape_norm.get_shape()))))
return true;
}
}
}

// MatMul specific checks from ::canFuse()
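// An i8/u8-output FakeQuantize child is only fusable here when the MatMul itself can be
// executed in int8.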
if (one_of(updatedChainType, NodeFusingType::FusedWithMatMul, NodeFusingType::FusedWithMatMulI8)) {
const auto rank = node->get_output_partial_shape(0).rank();
if (ov::is_type<ov::op::v0::FakeQuantize>(node)) {
if (one_of(node->get_output_element_type(0), ov::element::i8, ov::element::u8) && !canMatMulBeExecutedInI8)
return false;
}
}

// FuseMatMulAndSimpleOperation or FuseFullyConnectedAndSimpleOperation
// Invoke SupportsFusingWithConvolution_Simple directly instead of isSuitableChildForFusingSimple to
// eliminate getNumNonConstInputs() check
if (SupportsFusingWithConvolution_Simple(node, fusingAxis)) {
size_t num_non_const_inputs = 0;
size_t num_mm_inputs = 0;
for (const auto &parent_out : node->input_values()) {
// Avoid the potentially expensive is_on_constant_path() traversal for inputs that are already on a MatMul fusing branch
if (one_of(GetNodeFusingType(parent_out.get_node_shared_ptr()), NodeFusingType::FusedWithMatMul, NodeFusingType::FusedWithMatMulI8,
NodeFusingType::FusedWithFC, NodeFusingType::FusedWithFCI8))
num_mm_inputs++;
else if (!ov::op::util::is_on_constant_path(parent_out))
num_non_const_inputs++;
}
if (num_non_const_inputs + num_mm_inputs != 1)
return false;

updatedChainType = NodeFusingType::FusedWithMisc;
return true;
}

return false;
}
} // namespace

bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr<ov::Model> &m) {
@@ -272,6 +394,13 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr<ov::Model> &m) {
if (isSuitablePoolChild(node)) {
PropagateIfHasOnlyChild(node, fusingChainType);
}
} else if (one_of(fusingChainType, NodeFusingType::FusedWithMatMul, NodeFusingType::FusedWithMatMulI8,
NodeFusingType::FusedWithFC, NodeFusingType::FusedWithFCI8)) {
const bool isExecutedInINT8 = one_of(fusingChainType, NodeFusingType::FusedWithMatMulI8, NodeFusingType::FusedWithFCI8);
// Handle fusings for both MatMul and FullyConnected
NodeFusingType updatedChainType = fusingChainType;
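// isSuitableChildForFusingMatMul may downgrade updatedChainType (e.g. to FusedWithMisc
// for generic eltwise children), so the updated value is what gets propagated below.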
if (isSuitableChildForFusingMatMul(node, isExecutedInINT8, updatedChainType, channelAxis))
PropagateIfHasOnlyChild(node, updatedChainType);
}
}
}
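
For illustration only (not part of this commit): a minimal sketch of the subgraph that the new MatMul/FullyConnected branch is meant to keep in a single fused chain, namely a MatMul followed by a bias Add and a Relu. The function name, shapes, and constant values below are assumptions made for this example.

#include <openvino/openvino.hpp>
#include <openvino/opsets/opset1.hpp>

std::shared_ptr<ov::Model> make_matmul_bias_relu() {
    auto data = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{1, 64});
    auto weights = ov::opset1::Constant::create(ov::element::f32, ov::Shape{64, 256}, {0.1f});
    auto matmul = std::make_shared<ov::opset1::MatMul>(data, weights, false, false);
    // Bias on a constant path with a single consumer, as required by isSuitableChildForFusingMatMul.
    auto bias = ov::opset1::Constant::create(ov::element::f32, ov::Shape{1, 256}, {0.0f});
    auto add = std::make_shared<ov::opset1::Add>(matmul, bias);
    // A simple activation, handled through SupportsFusingWithConvolution_Simple.
    auto relu = std::make_shared<ov::opset1::Relu>(add);
    return std::make_shared<ov::Model>(ov::OutputVector{relu->output(0)}, ov::ParameterVector{data});
}
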
@@ -34,7 +34,7 @@ enum class NodeFusingType : int64_t {
NotSet,
FusedTerminator,
FusedWithConvolution, FusedWithBinaryConvolution,
FusedWithMatMul, FusedWithFC, FusedWithMisc};
FusedWithMatMul, FusedWithFC, FusedWithMatMulI8, FusedWithFCI8, FusedWithMisc};

} // namespace intel_cpu
} // namespace ov
@@ -74,20 +74,30 @@ const std::vector<FullyConnectedParams> activations = {
true, // activation
false, // per-channel
true, // FQ
false, // bias
"fullyConnected,relu_original"
},
{
false, // activation
false, // per-channel
true, // FQ
"fullyConnected_original"
false, // bias
"fullyConnected_original,fullyConnected"
},
{
true, // activation
true, // per-channel
false, // FQ
false, // bias
"fullyConnected,relu_original" // dequantization is not supported for per-channel quantization
},
{
true, // activation
false, // per-channel
true, // FQ
true, // bias
"fullyConnected,fullyConnected/DequantizationMultiply,add,relu"
},
};

INSTANTIATE_TEST_SUITE_P(smoke_LPT, FullyConnectedTransformation,
@@ -44,18 +44,21 @@ const std::vector<FullyConnectedParams> activations = {
true, // activation
false, // per-channel
true, // FQ
false, // bias
"fullyconnected,relu_original,relu"
},
{
false, // activation
false, // per-channel
true, // FQ
false, // bias
"fullyConnected_original,fullyConnected"
},
{
true, // activation
true, // per-channel
false, // FQ
false, // bias
"fullyconnected,relu_original,relu"
},
};
@@ -45,18 +45,21 @@ const std::vector<FullyConnectedParams> activations = {
true, // activation
false, // per-channel
true, // FQ
false, // bias
""
},
{
false, // activation
false, // per-channel
true, // FQ
false, // bias
""
},
{
true, // activation
true, // per-channel
false, // FQ
false, // bias
""
},
};
@@ -21,6 +21,7 @@ class FullyConnectedParams {
bool activation;
bool perChannelWeights;
bool fq;
bool bias;
std::string originalLayersNames;
};
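
As a usage illustration (not part of the commit), the new bias test case defined earlier corresponds to an entry of this structure initialized as follows; the variable name is hypothetical and the field order matches the declaration above.

// Hypothetical entry mirroring the new bias test case:
const FullyConnectedParams withBiasCase{
    true,   // activation: a Relu follows the MatMul
    false,  // perChannelWeights
    true,   // fq
    true,   // bias: an Add is inserted after the MatMul
    "fullyConnected,fullyConnected/DequantizationMultiply,add,relu"
};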

@@ -36,6 +36,7 @@ std::string FullyConnectedTransformation::getTestCaseName(const testing::TestPar
"Activation=" << activation.activation << "_" <<
"perChannelWeights=" << activation.perChannelWeights << "_" <<
"FQ=" << activation.fq << "_" <<
"withBias=" << activation.bias << "_" <<
activation.originalLayersNames << "_" <<
expectedPrimitiveType;

@@ -60,6 +61,7 @@
shapes.transposeA,
shapes.transposeB,
weightsType == ov::element::i8,
activation.bias,
activation.perChannelWeights,
activation.activation,
activation.fq);
@@ -29,6 +29,7 @@ class MatMulFunction {
const bool transpose1,
const bool transpose2,
const bool signedWeights,
const bool bias,
const bool perChannelWeightsDequantization,
const bool relu,
const bool fq);
7 changes: 7 additions & 0 deletions src/tests/ov_helpers/ov_lpt_models/src/mat_mul.cpp
@@ -71,6 +71,7 @@ std::shared_ptr<ov::Model> MatMulFunction::getOriginal(
const bool transpose1,
const bool transpose2,
const bool signedOnWeights,
const bool bias,
const bool perChannelWeightsDequantization,
const bool relu,
const bool fq) {
@@ -125,6 +126,12 @@
transpose2);
parent->set_friendly_name("fullyConnected");

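// When bias is requested, append an Add whose constant covers the full MatMul output
// shape; its friendly name "add" is what the expected originalLayersNames strings in
// the test definitions refer to.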
if (bias) {
const auto bias_const = ov::test::utils::make_constant(precision, parent->get_output_shape(0));
parent = std::make_shared<ov::opset1::Add>(parent, bias_const);
parent->set_friendly_name("add");
}

if (relu) {
parent = std::make_shared<ov::opset1::Relu>(parent);
parent->set_friendly_name("relu");
