ONNX WOQ supports different dtypes (#1490)
ONNX WOQ supports different dtypes

Signed-off-by: Mengni Wang <[email protected]>
Signed-off-by: yuwenzho <[email protected]>
mengniwang95 authored Dec 23, 2023
1 parent 08221e1 commit 5119fcb
Showing 6 changed files with 208 additions and 46 deletions.
9 changes: 5 additions & 4 deletions neural_compressor/adaptor/onnxrt.py
@@ -979,12 +979,10 @@ def _pre_optimize(self, model, level=1):
         sess_options.register_custom_ops_library(get_library_path())

         if not model.is_large_model:
-            sess = ort.InferenceSession(
-                model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
-            )
+            sess = ort.InferenceSession(model.model.SerializeToString(), sess_options, providers=[self.backend])
         elif model.model_path is not None:  # pragma: no cover
             model.model = onnx.ModelProto()  # clean memory for large model
-            sess = ort.InferenceSession(model.model_path, sess_options, providers=["CPUExecutionProvider"])
+            sess = ort.InferenceSession(model.model_path, sess_options, providers=[self.backend])
         else:  # pragma: no cover
             logger.warning("Please use model path instead of onnx model object to quantize")
         del sess
@@ -1914,6 +1912,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
                     mse=mse,
                     perchannel=perchannel,
                     accuracy_level=accuracy_level,
+                    providers=[self.backend],
                 )
             if "AWQ" in algos:
                 from neural_compressor.adaptor.ox_utils.weight_only import awq_quantize
@@ -1931,6 +1930,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
                     enable_auto_scale=enable_auto_scale,
                     enable_mse_search=enable_mse_search,
                     accuracy_level=accuracy_level,
+                    providers=[self.backend],
                 )
             elif "RTN" in algos:
                 from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize
@@ -1940,6 +1940,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
                     tmp_model,
                     quant_config,
                     accuracy_level=accuracy_level,
+                    providers=[self.backend],
                 )
             tmp_model.q_config = copy.deepcopy(quant_config)
             self._dump_model_op_stats(tmp_model, tune_cfg)
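
The onnxrt.py change drops the hard-coded CPUExecutionProvider and threads the adaptor's configured backend (self.backend) into the pre-optimization session and the weight-only helpers (gptq_quantize, awq_quantize, rtn_quantize), presumably so the inference sessions they create run on the same backend. A minimal standalone sketch of the session part, assuming onnxruntime is installed and using a hypothetical model.onnx path with a CUDA backend:

import onnxruntime as ort

# Hypothetical stand-in for self.backend on the adaptor; any entry from
# ort.get_available_providers() could be used here.
backend = "CUDAExecutionProvider"

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC

# Same call shape as in the diff: model path (or serialized bytes), session
# options, and an explicit provider list instead of a hard-coded CPU provider.
sess = ort.InferenceSession("model.onnx", sess_options, providers=[backend])
print(sess.get_providers())  # shows which providers the session actually got
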
21 changes: 21 additions & 0 deletions neural_compressor/adaptor/onnxrt_cuda.yaml
@@ -17,6 +17,20 @@
 -
   version:
     name: '1.6.0'
+  weight_only_integer: &cap_weight_only {
+    'MatMul': &cap_weight_only_matmul {
+      'weight': {
+        'dtype': ['int'], # no need to care uint
+        'bits': [4, 3, 8], # [1-8]
+        'group_size': [32, -1, 1, 16, 64, 128, 256, 512, 1024], # [1-inf]
+        'scheme': ['sym', 'asym'], # sym, no ZP
+        'algorithm': ['RTN', 'AWQ', 'GPTQ']
+      },
+      'activation': {
+        'dtype': ['fp32']
+      }
+    },
+  }
   int8: &ref_1_6 {
     'static': &ref_1_6_static {
       'Conv': {
@@ -114,6 +128,7 @@
 -
   version:
     name: '1.7.0'
+  weight_only_integer: *cap_weight_only
   int8: {
     'static': {
       'FusedConv': {
@@ -155,6 +170,7 @@
 -
   version:
     name: '1.8.0'
+  weight_only_integer: *cap_weight_only
   int8: {
     'static': {
       'FusedConv': {
@@ -224,6 +240,7 @@
 -
   version:
     name: '1.9.0'
+  weight_only_integer: *cap_weight_only
   int8: {
     'static': {
       'FusedConv': {
@@ -300,6 +317,7 @@
 -
   version:
     name: '1.10.0'
+  weight_only_integer: *cap_weight_only
   int8: {
     'static': {
       'FusedConv': {
@@ -356,6 +374,7 @@
 -
   version:
     name: '1.11.0'
+  weight_only_integer: *cap_weight_only
   int8: &ref_1_11 {
     'static': {
       'FusedConv': {
@@ -427,6 +446,7 @@
 -
   version:
     name: '1.12.0'
+  weight_only_integer: *cap_weight_only
   int8: *ref_1_11
   fp16: *common_fp16
   bf16: *common_bf16
@@ -436,6 +456,7 @@
 -
   version:
     name: 'default'
+  weight_only_integer: *cap_weight_only
   int8: *ref_1_6
   fp16: *common_fp16
   bf16: *common_bf16
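
The onnxrt_cuda.yaml change declares the weight-only MatMul capability once under the &cap_weight_only anchor (int weights, 4/3/8 bits, several group sizes, sym/asym schemes, RTN/AWQ/GPTQ) and attaches it to every later ONNX Runtime version entry through the *cap_weight_only alias. A minimal sketch of that anchor/alias reuse, assuming PyYAML is available and with the capability table trimmed to a few keys:

import yaml

# Trimmed, hypothetical excerpt mirroring the structure of onnxrt_cuda.yaml:
# the first version entry anchors the capability, the second aliases it.
snippet = """
- version:
    name: '1.6.0'
  weight_only_integer: &cap_weight_only
    MatMul:
      weight:
        dtype: ['int']
        bits: [4, 3, 8]
        scheme: ['sym', 'asym']
      activation:
        dtype: ['fp32']
- version:
    name: '1.7.0'
  weight_only_integer: *cap_weight_only
"""

caps = yaml.safe_load(snippet)
# The alias resolves to the same mapping object, so every version entry
# shares one weight-only capability table.
assert caps[0]["weight_only_integer"] is caps[1]["weight_only_integer"]
print(caps[1]["weight_only_integer"]["MatMul"]["weight"]["bits"])  # [4, 3, 8]
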
3 changes: 3 additions & 0 deletions neural_compressor/adaptor/ox_utils/util.py
@@ -57,6 +57,7 @@

 dtype_mapping = {
     "fp32": 1,
+    "float32": 1,
     "uint8": 2,
     "int8": 3,
     "uint16": 4,
@@ -66,12 +67,14 @@
     "string": 8,
     "bool": 9,
     "fp16": 10,
+    "float16": 10,
     "double": 11,
     "uint32": 12,
     "uint64": 13,
     "complex64": 14,
     "complex128": 15,
     "bf16": 16,
+    "bfloat16": 16,
 }

 PROVIDERS = {
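
The util.py change adds long-form dtype spellings (float32, float16, bfloat16) next to the existing short ones, mapping to the same integer codes, which match the onnx.TensorProto data-type enum. A minimal consistency check, assuming the onnx package is installed:

from onnx import TensorProto

# Subset of dtype_mapping from neural_compressor/adaptor/ox_utils/util.py,
# limited to the aliases touched by this commit.
dtype_mapping = {
    "fp32": 1,
    "float32": 1,
    "fp16": 10,
    "float16": 10,
    "bf16": 16,
    "bfloat16": 16,
}

# Each pair of spellings resolves to the same ONNX tensor element type.
assert dtype_mapping["fp32"] == dtype_mapping["float32"] == TensorProto.FLOAT
assert dtype_mapping["fp16"] == dtype_mapping["float16"] == TensorProto.FLOAT16
assert dtype_mapping["bf16"] == dtype_mapping["bfloat16"] == TensorProto.BFLOAT16
print("dtype aliases agree with onnx.TensorProto")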
