[GPU] Fix issue for Flux.1 int4 WC model memory overuse. (openvinotoolkit#27376)

Flux.1's text_encoder_2 int4 model has an unusual weight-compression (WC) pattern: its decompression chain carries an extra Multiply. Add this pattern to the target matcher.
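For context, the shape of the decompression chain in question, sketched with standard OpenVINO ops (the op versions are real; the helper name, shapes, and constant values are illustrative assumptions, not taken from the model):

```cpp
#include <memory>

#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/reshape.hpp>

// Hypothetical builder for a Flux.1-style int4 weight-decompression chain.
// The final Multiply after the Reshape is the part the matcher previously missed.
std::shared_ptr<ov::Node> make_flux_like_wc_chain() {
    // Grouped int4 weights: [O, G, K/G] (shape is illustrative)
    auto weights = ov::op::v0::Constant::create(ov::element::u4, ov::Shape{64, 4, 8}, {1});
    auto convert = std::make_shared<ov::op::v0::Convert>(weights, ov::element::f16);

    // Per-group decompression scale: [O, G, 1]
    auto scale = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{64, 4, 1}, {0.01f});
    auto mul = std::make_shared<ov::op::v1::Multiply>(convert, scale);

    // 3D -> 2D reshape to the weight shape FullyConnected expects
    auto target = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{2}, {64, 32});
    auto reshape = std::make_shared<ov::op::v1::Reshape>(mul, target, false);

    // The unusual extra Multiply that Flux.1 text_encoder_2 applies after the reshape
    auto scale2 = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{1, 1}, {0.5f});
    return std::make_shared<ov::op::v1::Multiply>(reshape, scale2);
}
```

The first Multiply applies the per-group decompression scale; the second, post-reshape Multiply is what the matcher below now recognizes, so the chain can still be fused into FullyConnectedCompressed instead of materializing decompressed weights in memory.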


### Tickets:
 - *154194*

---------

Signed-off-by: hyunback <[email protected]>
hyunback authored Nov 5, 2024
1 parent 9eab491 commit b8c5bd8
Showing 3 changed files with 58 additions and 4 deletions.
@@ -57,13 +57,16 @@ ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyCon
     auto reshape_const_m = wrap_type<ov::op::v0::Constant>();
     auto reshape_m = wrap_type<ov::op::v1::Reshape>({mul_m, reshape_const_m}, reshape_3d_to_2d);
 
+    auto mul2_const_m = wrap_type<ov::op::v0::Constant>();
+    auto mul2_m = wrap_type<ov::op::v1::Multiply>({reshape_m, mul2_const_m});
+
     auto transpose_input = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{reshape_m, mul_m});
     auto transpose_const_m = wrap_type<ov::op::v0::Constant>();
     auto transpose_m = wrap_type<ov::op::v1::Transpose>({transpose_input, transpose_const_m});
 
     auto data_m = any_input();
     auto bias_m = any_input();
-    auto weights_input_m = std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{reshape_m, transpose_m, mul_m});
+    auto weights_input_m = std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{reshape_m, transpose_m, mul_m, mul2_m});
     auto fully_connected_m = wrap_type<op::FullyConnected>({data_m, weights_input_m, bias_m});
 
     ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
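Since the Or branch is what makes the new pattern optional, here is a minimal standalone sketch of the idiom (illustrative names; not the PR code): the matcher accepts either branch, and the callback can later ask `pattern_map.count(mul2_m)` to see whether the extra Multiply participated in the match.

```cpp
#include <openvino/op/constant.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/pass/pattern/op/label.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>

using namespace ov::pass::pattern;

// Weights may come from a single scale Multiply or from Multiply -> Multiply;
// the Or node lets one matcher pattern cover both shapes of the graph.
auto scale_const  = wrap_type<ov::op::v0::Constant>();
auto base_mul     = wrap_type<ov::op::v1::Multiply>({any_input(), scale_const});
auto scale2_const = wrap_type<ov::op::v0::Constant>();
auto extra_mul    = wrap_type<ov::op::v1::Multiply>({base_mul, scale2_const});
auto weights      = std::make_shared<op::Or>(ov::OutputVector{base_mul, extra_mul});
```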
@@ -131,6 +134,7 @@ ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyCon
         std::shared_ptr<ov::Node> fc_input_zp = optional_zero_point;
         std::shared_ptr<ov::Node> fc_input_bias = pattern_map.at(bias_m).get_node_shared_ptr();
         std::vector<std::shared_ptr<ov::Node>> result_nodes = {};
+
         if (has_transpose) {
             const auto& transpose = pattern_map.at(transpose_m).get_node_shared_ptr();
             std::shared_ptr<ov::Node> transpose_const = pattern_map.at(transpose_const_m).get_node_shared_ptr();
@@ -151,6 +155,11 @@ ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyCon
             }
         }
 
+        if (pattern_map.count(mul2_m)) {
+            auto mul2_op_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(pattern_map.at(mul2_const_m).get_node_shared_ptr());
+            fc_input_scale = ov::op::util::eltwise_fold<ov::op::v1::Multiply>(fc_input_scale, mul2_op_const).get_node_shared_ptr();
+        }
+
         std::shared_ptr<ov::Node> new_fc = nullptr;
         if (with_zero_point) {
             new_fc = std::make_shared<op::FullyConnectedCompressed>(fc_input_a,
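`eltwise_fold` is an intel_gpu transformation utility; its effect here is to pre-multiply the two scale constants at transformation time so FullyConnectedCompressed still sees exactly one scale input. This is safe because (W ⊙ s1) ⊙ s2 = W ⊙ (s1 ⊙ s2) for broadcastable elementwise multiplies, and it is the heart of the memory fix: no decompressed intermediate tensor is ever created. A sketch of the same folding using core OpenVINO constant folding (assumed equivalent in effect to `eltwise_fold`; not the plugin utility itself):

```cpp
#include <memory>

#include <openvino/core/type.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/multiply.hpp>

// Fold scale * extra_scale into a single new Constant; nullptr if not foldable.
std::shared_ptr<ov::op::v0::Constant> fold_scales(const std::shared_ptr<ov::Node>& scale,
                                                  const std::shared_ptr<ov::Node>& extra_scale) {
    auto mul = std::make_shared<ov::op::v1::Multiply>(scale, extra_scale);
    ov::OutputVector folded(1);
    if (!mul->constant_fold(folded, {scale, extra_scale}))
        return nullptr;  // inputs were not constants, or folding is unsupported
    return ov::as_type_ptr<ov::op::v0::Constant>(folded[0].get_node_shared_ptr());
}
```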
@@ -171,6 +180,7 @@ ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyCon
         new_fc->set_friendly_name(fc->get_friendly_name());
         ov::copy_runtime_info(m.get_matched_nodes(), result_nodes);
         ov::replace_node(fc, new_fc);
+
         return true;
     };
 
src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp (19 changes: 17 additions & 2 deletions)
@@ -180,14 +180,29 @@ static bool is_decompression_multiply(const std::shared_ptr<const ov::Node> node
     if (all_has_types(consumers, { ov::op::v0::MatMul::get_type_info_static(), ov::op::v8::Gather::get_type_info_static() }))
         return true;
 
+    auto are_multiply_from_decompression = [&all_has_types](const ov::Input<ov::Node> consumer) {
+        if (!cldnn::one_of(consumer.get_node()->get_type_info(), { ov::op::v1::Multiply::get_type_info_static() }))
+            return false;
+        const auto child_consumers = consumer.get_node()->get_output_target_inputs(0);
+        if (all_has_types(child_consumers, { ov::opset1::MatMul::get_type_info_static(), ov::op::v8::Gather::get_type_info_static() }))
+            return true;
+        return false;
+    };
+
-    auto are_converts_from_decompression = [&all_has_types](const std::set<ov::Input<ov::Node>>& consumers) {
+    auto are_converts_from_decompression = [&all_has_types, &are_multiply_from_decompression](const std::set<ov::Input<ov::Node>>& consumers) {
         if (!all_has_types(consumers, { ov::opset1::Convert::get_type_info_static() }))
             return false;
         for (const auto& consumer : consumers) {
             const auto child_consumers = consumer.get_node()->get_output_target_inputs(0);
-            if (!all_has_types(child_consumers, { ov::opset1::MatMul::get_type_info_static(), ov::op::v8::Gather::get_type_info_static() }))
+            for (const auto& child_consumer : child_consumers) {
+                const auto& type_info = child_consumer.get_node()->get_type_info();
+                if (cldnn::one_of(type_info, { ov::opset1::MatMul::get_type_info_static(), ov::op::v8::Gather::get_type_info_static() }))
+                    continue;
+                if (are_multiply_from_decompression(child_consumer)) {
+                    continue;
+                }
                 return false;
+            }
         }
         return true;
     };
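The two lambdas together implement one rule, which may be easier to read flattened into a single recursive check. A simplified sketch (an assumption-level restatement, with a depth cap standing in for the fixed two-level lookahead; not the plugin code):

```cpp
#include <openvino/core/node.hpp>
#include <openvino/core/type.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/gather.hpp>
#include <openvino/op/matmul.hpp>
#include <openvino/op/multiply.hpp>

// A node stays on the decompression path if every consumer is a MatMul/Gather,
// or a Convert/Multiply whose own consumers satisfy the same rule.
bool feeds_matmul_or_gather(const ov::Node* node, int depth = 2) {
    for (const auto& input : node->get_output_target_inputs(0)) {
        const auto* user = input.get_node();
        if (ov::is_type<ov::op::v0::MatMul>(user) || ov::is_type<ov::op::v8::Gather>(user))
            continue;
        const bool passthrough =
            ov::is_type<ov::op::v0::Convert>(user) || ov::is_type<ov::op::v1::Multiply>(user);
        if (depth > 0 && passthrough && feeds_matmul_or_gather(user, depth - 1))
            continue;
        return false;  // some consumer is outside the decompression pattern
    }
    return true;
}
```

The extra Multiply branch is what lets the Flux.1 chain (Convert followed by a second Multiply before the MatMul) still be classified as a decompression multiply, so the int4 weights stay compressed on the GPU.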
@@ -30,7 +30,7 @@ using ov::test::InputShape;
 *        \       /
 *        Multiply
 *            |
-*   Data(F32)   Transpose(optional)
+*   Data(F32)   Transpose(optional) or Multiply(optional)
 *        \       /
 *         Matmul
 *            |
@@ -56,6 +56,7 @@ using MatmulWeightsDecompressionParams = std::tuple<ShapeParams,  //
                                                     bool,      // transpose on weights
                                                     bool,      // decompression subtract
                                                     bool,      // reshape on decompression constants
+                                                    bool,      // extra multiply
                                                     bool,      // per-tensor zero-point
                                                     uint64_t   // dynamic_quantization_group_size
                                                     >;
@@ -70,6 +71,7 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
         bool transpose;
         bool decompression_sub;
         bool reshape_on_decompression;
+        bool extra_multiply;
         bool per_tensor_zp;
         uint64_t dyn_quan_group_size;
 
@@ -79,6 +81,7 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
                  transpose,
                  decompression_sub,
                  reshape_on_decompression,
+                 extra_multiply,
                  per_tensor_zp,
                  dyn_quan_group_size) = obj.param;
 
@@ -95,6 +98,7 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
         result << "transpose_weights=" << transpose << "_";
         result << "decompression_subtract=" << decompression_sub << "_";
         result << "reshape_on_decompression=" << reshape_on_decompression << "_";
+        result << "extra_multiply=" << extra_multiply << "_";
         result << "per_tensor_zp=" << per_tensor_zp << "_";
         result << "dyn_quan_group_size=" << dyn_quan_group_size;
 
@@ -110,6 +114,7 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
                                      const bool transpose_weights,
                                      const bool add_subtract,
                                      const bool reshape_on_decompression,
+                                     const bool extra_multiply,
                                      const bool per_tensor_zp) {
         ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(data_precision, data_shape)};
         const auto weights_subgraph = init_compressed_weights_subgraph(weights_shape,
@@ -119,6 +124,7 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
                                                                        transpose_weights,
                                                                        add_subtract,
                                                                        reshape_on_decompression,
+                                                                       extra_multiply,
                                                                        per_tensor_zp);
 
         auto mat_mul = std::make_shared<ov::op::v0::MatMul>(params[0], weights_subgraph);
@@ -132,6 +138,7 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
                                                          const bool transpose_weights,
                                                          const bool add_subtract,
                                                          const bool reshape_on_decompression_constant,
+                                                         const bool extra_multiply,
                                                          const bool per_tensor_zp) {
         auto transpose_if_necessary = [&](const ov::Shape& shape) {
             auto result_shape = shape;
@@ -229,6 +236,8 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
             std::swap(*order.rbegin(), *(order.rbegin() + 1));
             auto transpose_constant = ov::op::v0::Constant::create(ov::element::i32, {rank}, order);
             last_node = std::make_shared<ov::op::v1::Transpose>(last_node, transpose_constant);
+        } else if (extra_multiply) {
+            last_node = std::make_shared<ov::op::v1::Multiply>(last_node, scale_const);
         }
         return last_node;
     }
@@ -242,6 +251,7 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
     bool transpose_weights;
     bool decompression_sub;
     bool reshape_on_decompression;
+    bool extra_multiply;
     bool per_tensor_zp;
     uint64_t dyn_quan_group_size;
 
@@ -251,6 +261,7 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
              transpose_weights,
              decompression_sub,
              reshape_on_decompression,
+             extra_multiply,
              per_tensor_zp,
              dyn_quan_group_size) = GetParam();
 
@@ -266,6 +277,7 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
                                           transpose_weights,
                                           decompression_sub,
                                           reshape_on_decompression,
+                                          extra_multiply,
                                           per_tensor_zp);
 
 
@@ -328,6 +340,20 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_basic,
                                             ::testing::Values(true),
                                             ::testing::Values(true),
                                             ::testing::Values(false),
+                                            ::testing::Values(false),
                                             ::testing::Values(0)),
                          MatmulWeightsDecompression::get_test_case_name);
+
+INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_extra_multiply,
+                         MatmulWeightsDecompression,
+                         ::testing::Combine(::testing::ValuesIn(input_shapes_basic),
+                                            ::testing::ValuesIn(weights_precisions),
+                                            ::testing::ValuesIn(activations_precisions),
+                                            ::testing::Values(false),
+                                            ::testing::Values(false),
+                                            ::testing::Values(false),
+                                            ::testing::Values(true),
+                                            ::testing::Values(false),
+                                            ::testing::Values(0)),
+                         MatmulWeightsDecompression::get_test_case_name);

@@ -356,6 +382,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_corner_cases_basic,
                                             ::testing::ValuesIn(transpose_weights),
                                             ::testing::ValuesIn(add_decompression_sub),
                                             ::testing::ValuesIn(reshape_on_decompression),
+                                            ::testing::Values(false),
                                             ::testing::ValuesIn(per_tensor_zp),
                                             ::testing::Values(0)),
                          MatmulWeightsDecompression::get_test_case_name);
@@ -368,6 +395,7 @@ INSTANTIATE_TEST_SUITE_P(MatMulCompressedWeights_corner_cases_big,
                                             ::testing::ValuesIn(transpose_weights),
                                             ::testing::ValuesIn(add_decompression_sub),
                                             ::testing::ValuesIn(reshape_on_decompression),
+                                            ::testing::Values(false),
                                             ::testing::ValuesIn(per_tensor_zp),
                                             ::testing::Values(0)),
                          MatmulWeightsDecompression::get_test_case_name);
@@ -384,6 +412,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_dyn_quan,
                                             ::testing::Values(false),
                                             ::testing::ValuesIn(add_decompression_sub),
                                             ::testing::Values(true),
+                                            ::testing::Values(false),
                                             ::testing::Values(true),  // per_tensor_zp
                                             ::testing::Values(UINT64_MAX)),
                          MatmulWeightsDecompression::get_test_case_name);