Skip to content

Commit

Permalink
Group & NF4 decompression temporary disabled
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Oct 6, 2023
1 parent 32ff879 commit cd8178d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,14 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
} else {
// We need to fuse Transpose to MatMul to have a simpler callback for the next transformation
CPU_REGISTER_PASS_COMMON(manager, ov::pass::TransposeMatMul);
const ov::element::TypeVector decompression_precisions{
ov::element::u8,
// TODO: Uncomment when group decompression is supported
// ov::element::nf4
};
// MarkDequantizationSubgraph is used even in non-LPT pipeline on X64 platforms
// in order to keep compressed MatMul weights with decompression operations as is
CPU_REGISTER_PASS_X64(manager, ov::pass::MarkDequantizationSubgraph, ov::element::TypeVector{ov::element::u8, ov::element::nf4}, true);
CPU_REGISTER_PASS_X64(manager, ov::pass::MarkDequantizationSubgraph, decompression_precisions, true);
CPU_SET_CALLBACK_X64(manager, [](const_node_ptr &node) -> bool {
auto get_single_consumer = [](const_node_ptr &node) -> std::shared_ptr<ov::Node> {
const auto consumers = node->get_output_target_inputs(0);
Expand All @@ -226,12 +231,14 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis

if (ov::is_type<ov::opset1::MatMul>(consumer)) {
return false;
} else if (ov::is_type<ov::opset1::Reshape>(consumer)) {
consumer = get_single_consumer(consumer);
if (consumer != nullptr && ov::is_type<ov::opset1::MatMul>(consumer)) {
return false;
}
}
// TODO: Uncomment when group decompression is supported
// else if (ov::is_type<ov::opset1::Reshape>(consumer)) {
// consumer = get_single_consumer(consumer);
// if (consumer != nullptr && ov::is_type<ov::opset1::MatMul>(consumer)) {
// return false;
// }
// }
return true;
}, ov::pass::MarkDequantizationSubgraph);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
void checkResults() {
const auto& test_param = GetParam();
const auto& weights_precision = std::get<1>(test_param);
// TODO: remove this condition when group decompression is supported
if (weights_precision == ov::element::nf4 || std::get<0>(test_param).weights_group_size != 1) {
return;
}
bool weights_found = false;
for (const auto& n : compiledModel.get_runtime_model()->get_ordered_ops()) {
if (n->get_friendly_name() == "Compressed_weights") {
Expand Down

0 comments on commit cd8178d

Please sign in to comment.