diff --git a/neural_compressor/adaptor/onnxrt.py b/neural_compressor/adaptor/onnxrt.py index 8247bc2a33e..738aa7833d0 100644 --- a/neural_compressor/adaptor/onnxrt.py +++ b/neural_compressor/adaptor/onnxrt.py @@ -979,12 +979,10 @@ def _pre_optimize(self, model, level=1): sess_options.register_custom_ops_library(get_library_path()) if not model.is_large_model: - sess = ort.InferenceSession( - model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"] - ) + sess = ort.InferenceSession(model.model.SerializeToString(), sess_options, providers=[self.backend]) elif model.model_path is not None: # pragma: no cover model.model = onnx.ModelProto() # clean memory for large model - sess = ort.InferenceSession(model.model_path, sess_options, providers=["CPUExecutionProvider"]) + sess = ort.InferenceSession(model.model_path, sess_options, providers=[self.backend]) else: # pragma: no cover logger.warning("Please use model path instead of onnx model object to quantize") del sess @@ -1914,6 +1912,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): mse=mse, perchannel=perchannel, accuracy_level=accuracy_level, + providers=[self.backend], ) if "AWQ" in algos: from neural_compressor.adaptor.ox_utils.weight_only import awq_quantize @@ -1931,6 +1930,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): enable_auto_scale=enable_auto_scale, enable_mse_search=enable_mse_search, accuracy_level=accuracy_level, + providers=[self.backend], ) elif "RTN" in algos: from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize @@ -1940,6 +1940,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): tmp_model, quant_config, accuracy_level=accuracy_level, + providers=[self.backend], ) tmp_model.q_config = copy.deepcopy(quant_config) self._dump_model_op_stats(tmp_model, tune_cfg) diff --git a/neural_compressor/adaptor/onnxrt_cuda.yaml b/neural_compressor/adaptor/onnxrt_cuda.yaml index 562fe410758..0f2547a62e3 100644 --- a/neural_compressor/adaptor/onnxrt_cuda.yaml +++ b/neural_compressor/adaptor/onnxrt_cuda.yaml @@ -17,6 +17,20 @@ - version: name: '1.6.0' + weight_only_integer: &cap_weight_only { + 'MatMul': &cap_weight_only_matmul { + 'weight': { + 'dtype': ['int'], # no need to care uint + 'bits': [4, 3, 8], # [1-8] + 'group_size': [32, -1, 1, 16, 64, 128, 256, 512, 1024], # [1-inf] + 'scheme': ['sym', 'asym'], # sym, no ZP + 'algorithm': ['RTN', 'AWQ', 'GPTQ'] + }, + 'activation': { + 'dtype': ['fp32'] + } + }, + } int8: &ref_1_6 { 'static': &ref_1_6_static { 'Conv': { @@ -114,6 +128,7 @@ - version: name: '1.7.0' + weight_only_integer: *cap_weight_only int8: { 'static': { 'FusedConv': { @@ -155,6 +170,7 @@ - version: name: '1.8.0' + weight_only_integer: *cap_weight_only int8: { 'static': { 'FusedConv': { @@ -224,6 +240,7 @@ - version: name: '1.9.0' + weight_only_integer: *cap_weight_only int8: { 'static': { 'FusedConv': { @@ -300,6 +317,7 @@ - version: name: '1.10.0' + weight_only_integer: *cap_weight_only int8: { 'static': { 'FusedConv': { @@ -356,6 +374,7 @@ - version: name: '1.11.0' + weight_only_integer: *cap_weight_only int8: &ref_1_11 { 'static': { 'FusedConv': { @@ -427,6 +446,7 @@ - version: name: '1.12.0' + weight_only_integer: *cap_weight_only int8: *ref_1_11 fp16: *common_fp16 bf16: *common_bf16 @@ -436,6 +456,7 @@ - version: name: 'default' + weight_only_integer: *cap_weight_only int8: *ref_1_6 fp16: *common_fp16 bf16: *common_bf16 diff --git a/neural_compressor/adaptor/ox_utils/util.py b/neural_compressor/adaptor/ox_utils/util.py 
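# Editorial sketch (not part of the patch): the onnxrt.py hunks above replace the
# hardcoded "CPUExecutionProvider" with the adaptor's configured backend, and the
# onnxrt_cuda.yaml hunks advertise the weight_only_integer capability per ORT
# version. A minimal, hypothetical illustration of the provider pass-through
# pattern, assuming names `backend` and `model_proto` (not identifiers from the
# repository), with a CPU fallback when the requested EP is unavailable:
import onnxruntime as ort

def create_session(model_proto, backend="CPUExecutionProvider"):
    """Create an InferenceSession on the requested execution provider.

    Falls back to CPU if the requested provider is not available in this
    onnxruntime build (e.g. a CPU-only wheel asked for the CUDA EP).
    """
    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    providers = (
        [backend] if backend in ort.get_available_providers() else ["CPUExecutionProvider"]
    )
    return ort.InferenceSession(model_proto.SerializeToString(), sess_options, providers=providers)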
index fb7fa97e047..8a393bbbb0a 100644 --- a/neural_compressor/adaptor/ox_utils/util.py +++ b/neural_compressor/adaptor/ox_utils/util.py @@ -57,6 +57,7 @@ dtype_mapping = { "fp32": 1, + "float32": 1, "uint8": 2, "int8": 3, "uint16": 4, @@ -66,12 +67,14 @@ "string": 8, "bool": 9, "fp16": 10, + "float16": 10, "double": 11, "uint32": 12, "uint64": 13, "complex64": 14, "complex128": 15, "bf16": 16, + "bfloat16": 16, } PROVIDERS = { diff --git a/neural_compressor/adaptor/ox_utils/weight_only.py b/neural_compressor/adaptor/ox_utils/weight_only.py index e77e2fecf7e..db32e99974a 100644 --- a/neural_compressor/adaptor/ox_utils/weight_only.py +++ b/neural_compressor/adaptor/ox_utils/weight_only.py @@ -29,7 +29,7 @@ from onnx import onnx_pb as onnx_proto from packaging.version import Version -from neural_compressor.adaptor.ox_utils.util import simple_progress_bar +from neural_compressor.adaptor.ox_utils.util import dtype_mapping, simple_progress_bar from neural_compressor.model.model import BaseModel from neural_compressor.model.onnx_model import ONNXModel from neural_compressor.utils.utility import LazyImport @@ -103,9 +103,13 @@ def make_matmul_weight_only_node( packed = np.reshape(packed, (-1, k_blocks, blob_size)) # build scale tensor - scale = np.reshape(scale, (-1, k_blocks)).astype("float32") + scale = np.reshape(scale, (-1, k_blocks)) scale_tensor = onnx.helper.make_tensor( - name=node.input[1] + "_scale", data_type=1, dims=scale.shape, vals=scale.tobytes(), raw=True + name=node.input[1] + "_scale", + data_type=dtype_mapping[str(scale.dtype)], + dims=scale.shape, + vals=scale.tobytes(), + raw=True, ) input_names.append(scale_tensor.name) new_inits.append(scale_tensor) @@ -138,7 +142,7 @@ def make_matmul_weight_only_node( kwargs["bits"] = num_bits kwargs["block_size"] = group_size if accuracy_level > 0: - # require onnxruntime > 1.16.2 + # require onnxruntime > 1.16.3 kwargs["accuracy_level"] = accuracy_level else: @@ -219,17 +223,17 @@ def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ra rmax = np.max(data, axis=1, keepdims=True) * ratio if scheme == "sym": max_range = np.maximum(np.abs(rmin), np.abs(rmax)) - scale = np.ones(rmax.shape, dtype="float32") + scale = np.ones(rmax.shape) scale[max_range > 0] = np.array( - [float(i) / (maxq - minq) for i in (max_range[max_range > 0] * 2.0).flatten().tolist()], dtype="float32" + [float(i) / (maxq - minq) for i in (max_range[max_range > 0] * 2.0).flatten().tolist()] ) zero_point = ( np.zeros(scale.shape) if dtype == "int" else np.ones(rmax.shape, dtype="uint8") * (1 << (num_bits - 1)) ) else: - scale = np.ones(rmax.shape, dtype="float32") + scale = np.ones(rmax.shape) scale[rmin != rmax] = np.array( - [float(i) / (maxq - minq) for i in (rmax - rmin)[rmin != rmax].flatten().tolist()], dtype="float32" + [float(i) / (maxq - minq) for i in (rmax - rmin)[rmin != rmax].flatten().tolist()] ) zero_point = ( ((np.zeros(scale.shape) - rmin) / scale).round() @@ -290,6 +294,7 @@ def rtn_quantize( scheme="asym", ratios={}, accuracy_level=0, + providers=["CPUExecutionProvider"], ): """Quant the model with round to nearst method. @@ -313,6 +318,7 @@ def rtn_quantize( accuracy_level (int): accuracy level. 
Support 0 (unset), 1(fp32 compute type of jblas kernel), 2 (fp16 compute type of jblas kernel), 3 (bf16 compute type of jblas kernel), 4 (int8 compute type of jblas kernel) + providers (list): providers to use Returns: model: fake quantized ONNXModel @@ -352,11 +358,16 @@ def rtn_quantize( weight = pad_tensor(weight, group_size, k_blocks) - if (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or ( + satisfy_MatMulNBits_condition = Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4 + satisfy_MatMulFpQ4_condition = ( Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32 + ) + if ("CUDAExecutionProvider" in providers and satisfy_MatMulNBits_condition) or ( + "CUDAExecutionProvider" not in providers + and (satisfy_MatMulFpQ4_condition or satisfy_MatMulNBits_condition) ): # pragma: no cover - # MatMulFpQ4 support 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions - # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1 + # MatMulFpQ4 support 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions, supported by CPU EP + # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1, supported by CPU EP AND CUDA EP q_weight, scale, zp = quant_tensor( weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1) ) @@ -367,7 +378,7 @@ def rtn_quantize( group_size=group_size, k_blocks=k_blocks, q_weight=q_weight.astype("uint8"), - scale=scale, + scale=scale.astype(dtype), zero_point=zp if scheme == "asym" else None, accuracy_level=accuracy_level, ) @@ -379,10 +390,10 @@ def rtn_quantize( q_weight = qdq_tensor(weight.T, num_bits, group_size, scheme, "int", ratios.get(node.input[1], 1)) q_weight = np.reshape(q_weight, (org_w_shape[1], -1)) q_weight = np.transpose(q_weight) - q_weight = q_weight[: org_w_shape[0], :].astype(weight.dtype) + q_weight = q_weight[: org_w_shape[0], :].astype(dtype) q_weight_tensor = onnx.helper.make_tensor( name=node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)), - data_type=1, + data_type=dtype_mapping[str(dtype)], dims=weight.shape, vals=q_weight.tobytes(), raw=True, @@ -425,6 +436,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, continue inp = np.concatenate(output_dicts[nodes[0].input[0]], axis=0) inp_scale = np.mean(np.reshape(np.abs(inp), (-1, inp[0].shape[-1])), axis=0) + dtype = None weight = [] org_out = [] for node in nodes: @@ -490,11 +502,17 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, init_share_num = model.get_initializer_share_num(node.input[1]) weight_tensor = model.get_initializer(node.input[1]) tensor = numpy_helper.to_array(weight_tensor, base_dir) - + dtype = tensor.dtype tensor = tensor.T * best_scale - tensor = (tensor.T).astype("float32") - - new_tensor = onnx.helper.make_tensor(node.input[1] + "_scaled", 1, tensor.shape, tensor.tobytes(), raw=True) + tensor = (tensor.T).astype(dtype) + + new_tensor = onnx.helper.make_tensor( + name=node.input[1] + "_scaled", + data_type=dtype_mapping[str(dtype)], + dims=tensor.shape, + vals=tensor.tobytes(), + raw=True, + ) model.add_initializer(new_tensor) node.input[1] = new_tensor.name @@ -510,8 +528,9 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, ) == len(nodes): for idx in [1, 2]: tensor = numpy_helper.to_array(model.get_initializer(parent.input[idx]), base_dir) + dtype = tensor.dtype new_tensor = tensor / np.reshape(best_scale, (1, -1)) - model.set_initializer(parent.input[idx], 
new_tensor.astype(tensor.dtype), raw=True) + model.set_initializer(parent.input[idx], new_tensor.astype(dtype), raw=True) updated_nodes.append(parent.name) output_dicts[parent.output[0]] = output_dicts[parent.output[0]] / np.reshape(best_scale, (1, -1)) @@ -523,8 +542,9 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, for inp in parent.input: if model.get_initializer(inp) is not None: tensor = numpy_helper.to_array(model.get_initializer(inp), base_dir) + dtype = tensor.dtype new_tensor = tensor / np.reshape(best_scale, (1, -1)) - model.set_initializer(inp, new_tensor.astype(tensor.dtype), raw=True) + model.set_initializer(inp, new_tensor.astype(dtype), raw=True) updated_nodes.append(parent.name) output_dicts[parent.output[0]] = output_dicts[parent.output[0]] / np.reshape(best_scale, (1, -1)) @@ -532,8 +552,9 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, nodes ): # pragma: no cover tensor = numpy_helper.to_array(model.get_initializer(parent.input[2]), base_dir) + dtype = tensor.dtype new_tensor = tensor / np.reshape(best_scale, (1, -1)) - model.set_initializer(parent.input[2], new_tensor.astype(tensor.dtype), raw=True) + model.set_initializer(parent.input[2], new_tensor.astype(dtype), raw=True) updated_nodes.append(parent.name) output_dicts[parent.output[0]] = output_dicts[parent.output[0]] / np.reshape(best_scale, (1, -1)) @@ -541,7 +562,7 @@ def apply_awq_scale(model, weight_config, absorb_pairs, output_dicts, num_bits, # insert mul scale_tensor = helper.make_tensor( name=parent.output[0] + "_weight_only_scale", - data_type=onnx_proto.TensorProto.FLOAT, + data_type=dtype_mapping[str(dtype)], dims=best_scale.shape, vals=(1.0 / best_scale).flatten().tolist(), ) @@ -622,13 +643,14 @@ def apply_awq_clip(model, weight_config, absorb_pairs, output_dicts, num_bits, g return ratios -def prepare_inputs(model, n_samples, dataloader): +def prepare_inputs(model, n_samples, dataloader, providers): """Prepare inputs for weight only quantization. Args: model (ModelProto or ONNXModel): onnx model n_samples (int, optional): calibration sample number. -1 means all samples. dataloader (object): dataloader for calibration. + providers (list): providers to use Returns: inputs: prepared inputs. @@ -653,9 +675,9 @@ def prepare_inputs(model, n_samples, dataloader): ) session = ( - ort.InferenceSession(model.model.SerializeToString(), so, providers=["CPUExecutionProvider"]) + ort.InferenceSession(model.model.SerializeToString(), so, providers=providers) if not model.is_large_model - else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=["CPUExecutionProvider"]) + else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=providers) ) inputs_names = [i.name for i in session.get_inputs()] del session @@ -689,6 +711,7 @@ def awq_quantize( enable_auto_scale=True, enable_mse_search=True, accuracy_level=0, + providers=["CPUExecutionProvider"], ): """Quant the model with Activation-aware Weight quantization(AWQ) method. @@ -715,6 +738,7 @@ def awq_quantize( accuracy_level (int): accuracy level. 
Support 0 (unset), 1(fp32 compute type of jblas kernel), 2 (fp16 compute type of jblas kernel), 3 (bf16 compute type of jblas kernel), 4 (int8 compute type of jblas kernel) + providers (list): providers to use Returns: model: fake quantized ONNXModel @@ -724,7 +748,7 @@ def awq_quantize( full_ratio = {} if enable_mse_search: - inputs, so = prepare_inputs(model, n_samples, dataloader) + inputs, so = prepare_inputs(model, n_samples, dataloader, providers) del dataloader org_output = copy.deepcopy(model.model.graph.output) @@ -750,9 +774,9 @@ def awq_quantize( ) session = ( - ort.InferenceSession(model.model.SerializeToString(), so, providers=["CPUExecutionProvider"]) + ort.InferenceSession(model.model.SerializeToString(), so, providers=providers) if not model.is_large_model - else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=["CPUExecutionProvider"]) + else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=providers) ) for input_name in output_names: @@ -764,6 +788,7 @@ def awq_quantize( node.op_type in ["MatMul"] and weight_config.get(node.name, {}) != "fp32" and weight_config.get(node.name, {}).get("algorithm", "AWQ") == "AWQ" + and model.get_initializer(node.input[1]) is not None ): dump_pairs[parent.name].append(model.get_node(node.name)) @@ -801,7 +826,7 @@ def awq_quantize( model.remove_tensors_from_outputs(output_names) model.model.graph.output.MergeFrom(org_output) - model = rtn_quantize(model, weight_config, num_bits, group_size, scheme, full_ratio, accuracy_level) + model = rtn_quantize(model, weight_config, num_bits, group_size, scheme, full_ratio, accuracy_level, providers) return model @@ -890,7 +915,6 @@ def find_params(weight): scales = [] zps = [] - dtype = W.dtype shape = W.shape scale, zp = find_params(W) dead = np.diag(H) == 0 @@ -944,7 +968,7 @@ def find_params(weight): invperm = np.argsort(perm) Q = Q[invperm, :] - Q = np.reshape(Q, W.shape).astype(dtype) + Q = np.reshape(Q, W.shape) del W return Q @@ -963,6 +987,7 @@ def gptq_quantize( mse=False, perchannel=True, accuracy_level=0, + providers=["CPUExecutionProvider"], ): """Quant the model with GPTQ method. @@ -992,6 +1017,7 @@ def gptq_quantize( accuracy_level (int): accuracy level. 
Support 0 (unset), 1(fp32 compute type of jblas kernel), 2 (fp16 compute type of jblas kernel), 3 (bf16 compute type of jblas kernel), 4 (int8 compute type of jblas kernel) + providers (list): providers to use Returns: model: fake quantized ONNXModel @@ -1000,7 +1026,7 @@ def gptq_quantize( base_dir = os.path.dirname(model.model_path) if model.model_path is not None else "" output_dicts = {} - inputs, so = prepare_inputs(model, n_samples, dataloader) + inputs, so = prepare_inputs(model, n_samples, dataloader, providers) del dataloader org_output = copy.deepcopy(model.model.graph.output) model.remove_tensors_from_outputs([i.name for i in org_output]) @@ -1024,9 +1050,9 @@ def gptq_quantize( ) session = ( - ort.InferenceSession(model.model.SerializeToString(), so, providers=["CPUExecutionProvider"]) + ort.InferenceSession(model.model.SerializeToString(), so, providers=providers) if not model.is_large_model - else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=["CPUExecutionProvider"]) + else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=providers) ) new_nodes = [] @@ -1041,6 +1067,7 @@ def gptq_quantize( node.op_type in ["MatMul"] and weight_config.get(node.name, {}) != "fp32" and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ" + and model.get_initializer(node.input[1]) is not None ): weight = numpy_helper.to_array( model.get_initializer(model.get_node(node.name).input[1]), base_dir @@ -1075,6 +1102,7 @@ def gptq_quantize( group_size = weight_config[node.name]["group_size"] scheme = weight_config[node.name]["scheme"] group_size = group_size if group_size != -1 else weight.shape[0] + dtype = weight.dtype q_weight = gptq( weight, @@ -1091,11 +1119,17 @@ def gptq_quantize( weight_tensor = model.get_initializer(node.input[1]) init_share_num = model.get_initializer_share_num(node.input[1]) - if (Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4) or ( + + satisfy_MatMulNBits_condition = Version(ort.__version__) > ONNXRT1161_VERSION and num_bits == 4 + satisfy_MatMulFpQ4_condition = ( Version(ort.__version__) >= ONNXRT116_VERSION and num_bits == 4 and group_size == 32 + ) + if ("CUDAExecutionProvider" in providers and satisfy_MatMulNBits_condition) or ( + "CUDAExecutionProvider" not in providers + and (satisfy_MatMulFpQ4_condition or satisfy_MatMulNBits_condition) ): # pragma: no cover - # MatMulFpQ4 support 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions - # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1 + # MatMulFpQ4 support 4 bits and 32 group_size with ort 1.16.0 and 1.16.1 versions, supported by CPU EP + # MatMulNBits supports 4 bits and 2^n group_size with ort > 1.16.1, supported by CPU EP AND CUDA EP org_shape = weight.shape k_blocks = (org_shape[0] + group_size - 1) // group_size q_weight = pad_tensor(q_weight, group_size, k_blocks) @@ -1107,7 +1141,7 @@ def gptq_quantize( group_size=group_size, k_blocks=k_blocks, q_weight=q_weight.astype("uint8"), - scale=scale, + scale=scale.astype(dtype), zero_point=zp if scheme == "asym" else None, accuracy_level=accuracy_level, ) @@ -1118,9 +1152,9 @@ def gptq_quantize( else: q_weight_tensor = onnx.helper.make_tensor( name=node.input[1] + "_Q{}G{}".format(str(num_bits), str(group_size)), - data_type=1, + data_type=dtype_mapping[str(dtype)], dims=q_weight.shape, - vals=q_weight.tobytes(), + vals=q_weight.astype(dtype).tobytes(), raw=True, ) model.add_initializer(q_weight_tensor) diff --git a/neural_compressor/model/model.py 
b/neural_compressor/model/model.py index 2c9c6358d97..06362d2d0d8 100644 --- a/neural_compressor/model/model.py +++ b/neural_compressor/model/model.py @@ -84,9 +84,9 @@ def _is_onnxruntime(model): so.register_custom_ops_library(get_library_path()) if isinstance(model, str): - ort.InferenceSession(model, so, providers=["CPUExecutionProvider"]) + ort.InferenceSession(model, so, providers=ort.get_available_providers()) else: - ort.InferenceSession(model.SerializeToString(), so, providers=["CPUExecutionProvider"]) + ort.InferenceSession(model.SerializeToString(), so, providers=ort.get_available_providers()) except Exception as e: # pragma: no cover if "Message onnx.ModelProto exceeds maximum protobuf size of 2GB" in str(e): logger.warning("Please use model path instead of onnx model object to quantize") diff --git a/test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py b/test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py index aeb6ba14c83..361f4aae75b 100644 --- a/test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py +++ b/test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py @@ -7,6 +7,7 @@ import numpy as np import onnx import onnxruntime as ort +from packaging.version import Version from transformers import AutoTokenizer from neural_compressor import PostTrainingQuantConfig, quantization @@ -48,6 +49,14 @@ def setUpClass(self): p.communicate() self.gptj_model = onnx.load("gptj/decoder_model.onnx") + self.gptj_fp16_model = None + if "CUDAExecutionProvider" in ort.get_available_providers(): + cmd = "optimum-cli export onnx --model hf-internal-testing/tiny-random-gptj --task text-generation --legacy --fp16 --device cuda gptj_fp16/" + p = subprocess.Popen( + cmd, preexec_fn=os.setsid, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True + ) # nosec + p.communicate() + self.gptj_fp16_model = onnx.load("gptj_fp16/decoder_model.onnx") self.gptj_dataloader = DummyNLPDataloader("hf-internal-testing/tiny-random-gptj") cmd = ( @@ -65,8 +74,102 @@ def setUpClass(self): def tearDownClass(self): shutil.rmtree("nc_workspace", ignore_errors=True) shutil.rmtree("gptj", ignore_errors=True) + shutil.rmtree("gptj_fp16", ignore_errors=True) shutil.rmtree("tiny-llama", ignore_errors=True) + @unittest.skipIf("CUDAExecutionProvider" not in ort.get_available_providers(), "Skip cuda woq test") + def test_RTN_quant_with_woq_op(self): + conf = PostTrainingQuantConfig( + approach="weight_only", + device="gpu", + backend="onnxrt_cuda_ep", + op_type_dict={ + ".*": { # re.match + "weight": { + "bits": 4, + "group_size": 32, + "scheme": "sym", + "algorithm": "RTN", + }, + }, + }, + ) + # test fp16 model + q_fp16_model = quantization.fit(self.gptj_fp16_model, conf) + + for data, _ in self.gptj_dataloader: + q_out = Inference(q_fp16_model.model, data) + org_out = Inference(self.gptj_fp16_model, data) + for q, org in zip(q_out, org_out): + self.assertTrue((np.abs(q_out[0] - org_out[0]) < 0.5).all()) + if Version(ort.__version__) > Version("1.16.1"): + scale_tensor = [i for i in q_fp16_model.initializer() if i.name.endswith("_scale")] + self.assertTrue(len(scale_tensor) > 0) + self.assertEqual(scale_tensor[0].data_type, 10) + self.assertTrue("MatMulNBits" in set([node.op_type for node in q_fp16_model.model.graph.node])) + + @unittest.skipIf("CUDAExecutionProvider" not in ort.get_available_providers(), "Skip cuda woq test") + def test_AWQ_quant_with_woq_op(self): + conf = PostTrainingQuantConfig( + approach="weight_only", + device="gpu", + backend="onnxrt_cuda_ep", + op_type_dict={ + ".*": { # re.match + 
"weight": { + "bits": 4, + "group_size": 32, + "scheme": "sym", + "algorithm": "AWQ", + }, + }, + }, + recipes={ + "awq_args": {"enable_auto_scale": True, "enable_mse_search": True}, + }, + ) + # test fp16 model + q_fp16_model = quantization.fit(self.gptj_fp16_model, conf, calib_dataloader=self.gptj_dataloader) + for data, _ in self.gptj_dataloader: + q_out = Inference(q_fp16_model.model, data) + org_out = Inference(self.gptj_fp16_model, data) + for q, org in zip(q_out, org_out): + self.assertTrue((np.abs(q_out[0] - org_out[0]) < 0.5).all()) + if Version(ort.__version__) > Version("1.16.1"): + scale_tensor = [i for i in q_fp16_model.initializer() if i.name.endswith("_scale")] + self.assertTrue(len(scale_tensor) > 0) + self.assertEqual(scale_tensor[0].data_type, 10) + self.assertTrue("MatMulNBits" in set([node.op_type for node in q_fp16_model.model.graph.node])) + + @unittest.skipIf("CUDAExecutionProvider" not in ort.get_available_providers(), "Skip cuda woq test") + def test_GPTQ_quant_with_woq_op(self): + conf = PostTrainingQuantConfig( + approach="weight_only", + device="gpu", + backend="onnxrt_cuda_ep", + op_type_dict={ + ".*": { # re.match + "weight": { + "bits": 4, + "group_size": 32, + "scheme": "sym", + "algorithm": "GPTQ", + }, + }, + }, + ) + q_fp16_model = quantization.fit(self.gptj_fp16_model, conf, calib_dataloader=self.gptj_dataloader) + for data, _ in self.gptj_dataloader: + q_out = Inference(q_fp16_model.model, data) + org_out = Inference(self.gptj_fp16_model, data) + for q, org in zip(q_out, org_out): + self.assertTrue((np.abs(q_out[0] - org_out[0]) < 0.5).all()) + if Version(ort.__version__) > Version("1.16.1"): + scale_tensor = [i for i in q_fp16_model.initializer() if i.name.endswith("_scale")] + self.assertTrue(len(scale_tensor) > 0) + self.assertEqual(scale_tensor[0].data_type, 10) + self.assertTrue("MatMulNBits" in set([node.op_type for node in q_fp16_model.model.graph.node])) + def test_RTN_quant(self): conf = PostTrainingQuantConfig( approach="weight_only",