[PT FE] Support AWQ models with GEMM module (#27624)
### Details:
 - *Support converting models quantized with the AWQ algorithm from Hugging Face (a usage sketch follows the changed-files summary below)*

### Tickets:
 - *CVS-136653*

---------

Signed-off-by: Maxim Vafin <[email protected]>
mvafin authored Dec 2, 2024
1 parent c692f0b commit d72415f
Showing 9 changed files with 259 additions and 66 deletions.
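
For context, a minimal usage sketch of the workflow this change enables: converting an AWQ-quantized Hugging Face checkpoint through `openvino.convert_model`, which now routes `WQLinear_GEMM` modules through the new `ov_ext::awq_gemm` extension. The checkpoint id, example input, and output path are illustrative assumptions (not part of this PR); `transformers` and `autoawq` need to be installed so the GEMM modules are present.

```python
# Illustrative sketch only: the checkpoint id and file name are placeholders.
import openvino as ov
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TheBloke/Llama-2-7B-AWQ"  # hypothetical AWQ (GEMM) checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()

example = tokenizer("Hello, world", return_tensors="pt")
# During tracing, the PyTorch frontend detects the AWQ quantization config and
# patches WQLinear_GEMM modules so they are converted via ov_ext::awq_gemm.
ov_model = ov.convert_model(model, example_input=dict(example))
ov.save_model(ov_model, "model_awq.xml")
```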
12 changes: 0 additions & 12 deletions src/bindings/python/src/openvino/frontend/pytorch/gptq.py
@@ -177,15 +177,3 @@ def unpatch_model(model):
log.warning("Exception raised during GPTQ model unpatching. "
"Depending on the exact issue it may lead to broken "
"original model.\n%s", error)


def detect_gptq_model_raw(model):
return (model and getattr(model, 'config', None) and
getattr(model.config, 'quantization_config', None) and
model.config.quantization_config.quant_method == 'gptq')


def detect_gptq_model(model):
return (detect_gptq_model_raw(model) or
getattr(model, 'model', None) and
detect_gptq_model_raw(model.model))
73 changes: 73 additions & 0 deletions src/bindings/python/src/openvino/frontend/pytorch/quantized.py
@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Optional
import torch
from openvino.frontend.pytorch import ModuleExtension, gptq
from openvino.frontend.pytorch.patch_model import patch_model, unpatch_model


def detect_quantized_model(model: torch.nn.Module) -> Optional[str]:
"""Detects the quantization method used in a given PyTorch model.
Args:
model (torch.nn.Module): The PyTorch model to check for quantization.
Returns:
str: The quantization method if available, otherwise None.
"""
if (model and getattr(model, "config", None)
and getattr(model.config, "quantization_config", None)):
return model.config.quantization_config.quant_method
if getattr(model, "model", None):
return detect_quantized_model(model.model)
return None


def patch_quantized(model: torch.nn.Module) -> None:
"""Patches a model based on its quantization type ("awq" or "gptq").
Args:
model (torch.nn.Module): The model to patch.
Raises:
RuntimeError: If the quantization type is unknown.
"""
quant_type = detect_quantized_model(model)
if quant_type == "awq":
extensions = {}
try:
from awq.modules.linear import WQLinear_GEMM
extensions[WQLinear_GEMM] = ModuleExtension(
WQLinear_GEMM, "ov_ext::awq_gemm",
convert=lambda module, target_op, *args, **kwargs: target_op(
args[0], module.qweight, module.qzeros, module.scales,
torch.tensor(module.group_size),
torch.tensor(module.w_bit), module.bias),
evaluate=lambda module, *args, **kwargs: torch.full(
list(args[0].shape[:-1]) + [module.out_features], 0.5,
dtype=torch.float32)) # type: ignore
except ImportError:
pass
patch_model(model, extensions,
"_openvino_quantized_patch_orig_forward") # type: ignore
elif quant_type == "gptq":
model._openvino_gptq_patched = True
gptq.patch_model(model) # type: ignore
else:
raise RuntimeError(f"Unknown quantization type: {quant_type}.")


def unpatch_quantized(model: torch.nn.Module) -> None:
"""Reverts the patching applied to a quantized PyTorch model.
Args:
model (torch.nn.Module): The model to unpatch.
"""
if getattr(model, "_openvino_gptq_patched", False):
gptq.unpatch_model(model) # type: ignore
del model._openvino_gptq_patched
else:
unpatch_model(model,
"_openvino_quantized_patch_orig_forward") # type: ignore
26 changes: 12 additions & 14 deletions src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py
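
Before the diff, a toy illustration of the detection helper added in quantized.py above; the `SimpleNamespace` config merely mimics a Hugging Face `quantization_config` and is not a real quantized model. In this decoder the same helpers bracket `torch.jit.trace`, as the hunk below shows.

```python
# Toy illustration of detect_quantized_model (the SimpleNamespace config is a
# stand-in for a real Hugging Face quantization_config, not a real AWQ model).
import types
import torch
from openvino.frontend.pytorch import quantized

model = torch.nn.Linear(4, 4)
model.config = types.SimpleNamespace(
    quantization_config=types.SimpleNamespace(quant_method="awq"))

print(quantized.detect_quantized_model(model))                  # -> "awq"
print(quantized.detect_quantized_model(torch.nn.Linear(4, 4)))  # -> None
```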
@@ -16,7 +16,7 @@
graph_has_ops,
)
from openvino.runtime import opset11 as ops
from openvino.frontend.pytorch import gptq, patch_model
from openvino.frontend.pytorch import quantized, patch_model
from openvino.frontend.pytorch.module_extension import ModuleExtension

import inspect
@@ -141,27 +141,25 @@ def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False)
patch_model.patch_model(
pt_module, self.module_extensions, orig_forward_name)

gptq_patched = False
if gptq.detect_gptq_model(pt_module):
patched = False
if quantized.detect_quantized_model(pt_module) is not None:
try:
gptq.patch_model(pt_module)
gptq_patched = True
quantized.patch_quantized(pt_module)
patched = True
except Exception as error:
log.warning(
"Failed patching of AutoGPTQ model. Error message:\n%s"
"\nTracing of the model will likely be unsuccessful or incorrect",
error)
gptq.unpatch_model(pt_module)
gptq_patched = False
"Failed patching of AutoGPTQ model. Error message:\n"
"Tracing of the model will likely be unsuccessful or incorrect",
exc_info=error)
quantized.unpatch_quantized(pt_module)
patched = False

try:
scripted = torch.jit.trace(
pt_module, **input_parameters, strict=False)
finally:
if gptq_patched:
gptq.unpatch_model(pt_module)
if self.module_extensions:
patch_model.unpatch_model(pt_module, orig_forward_name)
if patched:
quantized.unpatch_quantized(pt_module)

have_to_freeze_ops = ["prim::Uninitialized",
"prim::unchecked_cast", "aten::append"]
88 changes: 84 additions & 4 deletions src/frontends/pytorch/src/op/linear.cpp
@@ -5,13 +5,19 @@
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/subtract.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_linear(const NodeContext& context) {
// schema: aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
num_inputs_check(context, 2, 3);
@@ -20,17 +26,91 @@ OutputVector translate_linear(const NodeContext& context) {
if (weight.get_element_type() == element::f16 || weight.get_element_type() == element::bf16) {
// In case of patched linear it can have mixed fp16/bf16 and fp32 input type.
// In other cases these conversion is not required.
weight = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(weight, x));
weight = context.mark_node(std::make_shared<v1::ConvertLike>(weight, x));
}
auto matmul = context.mark_node(std::make_shared<ov::op::v0::MatMul>(x, weight, false, true));
auto matmul = context.mark_node(std::make_shared<v0::MatMul>(x, weight, false, true));
if (!context.input_is_none(2)) {
auto bias = context.get_input(2);

if (bias.get_element_type() == element::f16 || bias.get_element_type() == element::bf16) {
// Same reason as for weight.
bias = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(bias, x));
bias = context.mark_node(std::make_shared<v1::ConvertLike>(bias, x));
}
matmul = context.mark_node(std::make_shared<v1::Add>(matmul, bias));
}
return {matmul};
};

namespace {
uint32_t rearrange_awq_bits(uint32_t num) {
uint32_t result = 0;
uint32_t mask = 0xF;

// Rearrange each 4-bit part in accordance with the AWQ i32->u4 unpacking schema
result |= (num & (mask << 0)) << 0;
result |= (num & (mask << 16)) >> 12;
result |= (num & (mask << 4)) << 4;
result |= (num & (mask << 20)) >> 8;
result |= (num & (mask << 8)) << 8;
result |= (num & (mask << 24)) >> 4;
result |= (num & (mask << 12)) << 12;
result |= (num & (mask << 28)) >> 0;

return result;
}

Output<Node> rearrange_constant(const Output<Node>& c, uint32_t groups) {
auto constant = std::dynamic_pointer_cast<v0::Constant>(c.get_node_shared_ptr());
FRONT_END_OP_CONVERSION_CHECK(constant, "weight must be Constant.");
auto src = constant->get_data_ptr<uint32_t>();
auto initial_shape = constant->get_shape();
FRONT_END_OP_CONVERSION_CHECK(initial_shape.size() == 2, "Only 2D constants are supported.");
auto new_shape = Shape{initial_shape[0] / groups, groups, initial_shape[1] * 8};
auto new_qweight = std::make_shared<v0::Constant>(element::u4, new_shape);
auto dst = const_cast<uint32_t*>(reinterpret_cast<const uint32_t*>(new_qweight->get_data_ptr()));
for (size_t i = 0; i < shape_size(constant->get_shape()); i++) {
dst[i] = rearrange_awq_bits(src[i]);
}
return new_qweight;
}
} // namespace

OutputVector translate_linear_awq(const NodeContext& context) {
num_inputs_check(context, 4, 7);
auto x = context.get_input(0);
auto qweight = context.get_input(1);
auto qzeros = context.get_input(2);
auto scales = context.get_input(3);
auto groups = context.const_input<int64_t>(4);
auto bits = context.const_input<int64_t>(5);

FRONT_END_OP_CONVERSION_CHECK(bits == 4, "Only 4 bit AWQ is supported.");

auto new_qweight = rearrange_constant(qweight, static_cast<uint32_t>(groups));
auto new_qzeros = rearrange_constant(qzeros, 1);
new_qweight = context.mark_node(std::make_shared<v0::Convert>(new_qweight, scales.get_element_type()));
new_qzeros = context.mark_node(std::make_shared<v0::Convert>(new_qzeros, scales.get_element_type()));

auto w_s = context.mark_node(std::make_shared<v1::Subtract>(new_qweight, new_qzeros));
FRONT_END_OP_CONVERSION_CHECK(scales.get_partial_shape().is_static(), "Scales must be constant.");
auto scales_shape = scales.get_shape();
auto new_scales_shape =
v0::Constant::create(element::i32, {3}, std::vector<uint64_t>{scales_shape[0], 1, scales_shape[1]});
scales = context.mark_node(std::make_shared<v1::Reshape>(scales, new_scales_shape, false));
auto weight = context.mark_node(std::make_shared<v1::Multiply>(w_s, scales));
auto out_shape =
v0::Constant::create(element::i32, {2}, std::vector<int32_t>{static_cast<int32_t>(qweight.get_shape()[0]), -1});
weight = context.mark_node(std::make_shared<v1::Reshape>(weight, out_shape, false));
weight = context.mark_node(std::make_shared<v1::ConvertLike>(weight, x));

auto matmul = context.mark_node(std::make_shared<v0::MatMul>(x, weight, false, false));
if (!context.input_is_none(6)) {
auto bias = context.get_input(6);

if (bias.get_element_type() == element::f16 || bias.get_element_type() == element::bf16) {
bias = context.mark_node(std::make_shared<v1::ConvertLike>(bias, x));
}
matmul = context.mark_node(std::make_shared<ov::op::v1::Add>(matmul, bias));
matmul = context.mark_node(std::make_shared<v1::Add>(matmul, bias));
}
return {matmul};
};
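
The helper above undoes AWQ's nibble interleaving: WQLinear_GEMM packs eight u4 weights into one int32 in the order [0, 2, 4, 6, 1, 3, 5, 7], so sequential output nibble i has to be read from input nibble [0, 4, 1, 5, 2, 6, 3, 7][i]. After the reorder the buffer can be reinterpreted directly as u4, and translate_linear_awq dequantizes it per group as (qweight - qzeros) * scales before the MatMul. A pure-Python mirror of the bit shuffle (illustrative only, not part of the PR):

```python
# Python mirror of rearrange_awq_bits: move input nibble AWQ_REVERSE_ORDER[i]
# to output nibble position i.
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]

def rearrange_awq_bits(num: int) -> int:
    nibbles = [(num >> (4 * i)) & 0xF for i in range(8)]
    out = 0
    for dst, src in enumerate(AWQ_REVERSE_ORDER):
        out |= nibbles[src] << (4 * dst)
    return out

# Pack the values 0..7 in AWQ's interleaved order, then unshuffle them:
packed = 0
for pos, val in enumerate([0, 2, 4, 6, 1, 3, 5, 7]):
    packed |= val << (4 * pos)
print(hex(rearrange_awq_bits(packed)))  # -> 0x76543210 (sequential again)
```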
7 changes: 5 additions & 2 deletions src/frontends/pytorch/src/op_table.cpp
@@ -61,7 +61,6 @@ OP_CONVERTER(translate_clamp);
OP_CONVERTER(translate_col2im);
OP_CONVERTER(translate_constant);
OP_CONVERTER(translate_conv_transposend);
OP_CONVERTER(translate_conv1d_ext);
OP_CONVERTER(translate_convnd);
OP_CONVERTER(translate_convolution);
OP_CONVERTER(translate_convolution_mode);
@@ -77,7 +76,6 @@ OP_CONVERTER(translate_dot);
OP_CONVERTER(translate_elu);
OP_CONVERTER(translate_embedding);
OP_CONVERTER(translate_embedding_bag);
OP_CONVERTER(translate_embedding_ext);
OP_CONVERTER(translate_empty);
OP_CONVERTER(translate_empty_like);
OP_CONVERTER(translate_erf);
@@ -325,6 +323,10 @@ OP_CONVERTER(translate_unbind_int_fx);
OP_CONVERTER(translate_unique2);
OP_CONVERTER(translate_zeros_fx);
OP_CONVERTER(translate_zeros_like_fx);
// Extensions
OP_CONVERTER(translate_conv1d_ext);
OP_CONVERTER(translate_embedding_ext);
OP_CONVERTER(translate_linear_awq);

} // namespace op

@@ -699,6 +701,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::zero", op::translate_zeros_like},
{"aten::zeros", op::translate_zeros},
{"aten::zeros_like", op::translate_zeros_like},
{"ov_ext::awq_gemm", op::translate_linear_awq},
{"ov_ext::embedding", op::translate_embedding_ext},
{"ov_ext::conv1d", op::translate_conv1d_ext},
{"ov_ext::linear", op::translate_linear},
6 changes: 5 additions & 1 deletion src/frontends/pytorch/src/utils.cpp
@@ -42,7 +42,11 @@ using namespace ov::op;

void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs) {
auto num_inputs = context.get_input_size();
FRONT_END_OP_CONVERSION_CHECK(num_inputs >= min_inputs, "Got less inputs than expected");
FRONT_END_OP_CONVERSION_CHECK(num_inputs >= min_inputs,
"Got less inputs ",
num_inputs,
" than expected ",
min_inputs);
for (auto i = max_inputs; i < num_inputs; i++) {
FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(i), "Got more inputs than expected.");
}
7 changes: 0 additions & 7 deletions tests/model_hub_tests/pytorch/detectron2_precommit
@@ -1,13 +1,8 @@
COCO-Detection/faster_rcnn_R_50_C4_1x,none
COCO-Detection/faster_rcnn_R_50_DC5_3x,none
COCO-Detection/faster_rcnn_R_50_FPN_1x,none
COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x,none
COCO-Detection/retinanet_R_50_FPN_1x,none
COCO-Detection/rpn_R_50_C4_1x,none
COCO-Detection/rpn_R_50_FPN_1x,none
COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x,none
COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x,none
COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_3x,none
COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x,none
COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x,none
COCO-Keypoints/keypoint_rcnn_X_101_32x8d_FPN_3x,none
@@ -19,8 +14,6 @@ LVISv0.5-InstanceSegmentation/mask_rcnn_R_50_FPN_1x,none
LVISv0.5-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_1x,none
Misc/cascade_mask_rcnn_R_50_FPN_3x,none
Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv,none
Misc/mask_rcnn_R_50_FPN_3x_dconv_c3-c5,none
Misc/mask_rcnn_R_50_FPN_3x_gn,none
Misc/mask_rcnn_R_50_FPN_3x_syncbn,none
Misc/scratch_mask_rcnn_R_50_FPN_9x_syncbn,none
PascalVOC-Detection/faster_rcnn_R_50_C4,none
