diff --git a/.azure-pipelines/scripts/ut/run_3x_pt.sh b/.azure-pipelines/scripts/ut/run_3x_pt.sh index b41e4dd3811..7a5c131144e 100644 --- a/.azure-pipelines/scripts/ut/run_3x_pt.sh +++ b/.azure-pipelines/scripts/ut/run_3x_pt.sh @@ -6,6 +6,7 @@ echo "${test_case}" # install requirements echo "set up UT env..." pip install -r /neural-compressor/requirements_pt.txt +pip install transformers pip install coverage pip install pytest pip list diff --git a/.azure-pipelines/ut-3x-pt.yml b/.azure-pipelines/ut-3x-pt.yml index 6b6b4776fd0..d48619c01c6 100644 --- a/.azure-pipelines/ut-3x-pt.yml +++ b/.azure-pipelines/ut-3x-pt.yml @@ -14,7 +14,6 @@ pr: - setup.py - requirements.txt - requirements_pt.txt - - .azure-pipelines/scripts/ut pool: ICX-16C diff --git a/.azure-pipelines/ut-basic-no-cover.yml b/.azure-pipelines/ut-basic-no-cover.yml index 3deef2e05a1..e0e3bb3ad7f 100644 --- a/.azure-pipelines/ut-basic-no-cover.yml +++ b/.azure-pipelines/ut-basic-no-cover.yml @@ -12,7 +12,6 @@ pr: - test - setup.py - requirements.txt - - .azure-pipelines/scripts/ut exclude: - test/neural_coder - test/3x diff --git a/.azure-pipelines/ut-basic.yml b/.azure-pipelines/ut-basic.yml index 5312b5bf6c9..0f22d7f467a 100644 --- a/.azure-pipelines/ut-basic.yml +++ b/.azure-pipelines/ut-basic.yml @@ -12,7 +12,6 @@ pr: - test - setup.py - requirements.txt - - .azure-pipelines/scripts/ut exclude: - test/neural_coder - test/3x diff --git a/neural_compressor/common/__init__.py b/neural_compressor/common/__init__.py index 8989ae9d722..67a891db75e 100644 --- a/neural_compressor/common/__init__.py +++ b/neural_compressor/common/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from neural_compressor.common.logger import level, log, info, debug, warn, warning, error, fatal diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py index 02dd897b37f..51a3f70c1dc 100644 --- a/neural_compressor/common/base_config.py +++ b/neural_compressor/common/base_config.py @@ -19,10 +19,14 @@ import json from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Optional, Union +from collections import OrderedDict +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from neural_compressor.common.logger import Logger +from neural_compressor.common.utility import BASE_CONFIG, COMPOSABLE_CONFIG, GLOBAL, LOCAL + +logger = Logger().get_logger() -from neural_compressor.common.utility import BASE_CONFIG, GLOBAL, OPERATOR_NAME -from neural_compressor.utils import logger # Dictionary to store registered configurations registered_configs = {} @@ -57,20 +61,36 @@ class BaseConfig(ABC): name = BASE_CONFIG def __init__(self) -> None: - self.global_config: Optional[BaseConfig] = None + self._global_config: Optional[BaseConfig] = None # For PyTorch, operator_type is the collective name for module type and functional operation type, # for example, `torch.nn.Linear`, and `torch.nn.functional.linear`. 
- self.operator_type_config: Dict[Union[str, Callable], Optional[BaseConfig]] = {} - self.operator_name_config: Dict[str, Optional[BaseConfig]] = {} - - def set_operator_name(self, operator_name: str, config: BaseConfig) -> BaseConfig: - self.operator_name_config[operator_name] = config - return self - - def _set_operator_type(self, operator_type: Union[str, Callable], config: BaseConfig) -> BaseConfig: - # TODO (Yi), clean the usage - # hide it from user, as we can use set_operator_name with regular expression to convert its functionality - self.operator_type_config[operator_type] = config + # local config is the collections of operator_type configs and operator configs + self._local_config: Dict[str, Optional[BaseConfig]] = {} + + @property + def global_config(self): + if self._global_config is None: + self._global_config = self.__class__(**self.to_dict()) + return self._global_config + + @global_config.setter + def global_config(self, config): + self._global_config = config + + @property + def local_config(self): + return self._local_config + + @local_config.setter + def local_config(self, config): + self._local_config = config + + def set_local(self, operator_name: str, config: BaseConfig) -> BaseConfig: + if operator_name in self.local_config: + logger.warning("The configuration for %s has already been set, update it.", operator_name) + if self.global_config is None: + self.global_config = self.__class__(**self.to_dict()) + self.local_config[operator_name] = config return self def to_dict(self, params_list=[], operator2str=None): @@ -78,10 +98,10 @@ def to_dict(self, params_list=[], operator2str=None): global_config = {} for param in params_list: global_config[param] = getattr(self, param) - if bool(self.operator_name_config): - result[OPERATOR_NAME] = {} - for op_name, config in self.operator_name_config.items(): - result[OPERATOR_NAME][op_name] = config.to_dict() + if bool(self.local_config): + result[LOCAL] = {} + for op_name, config in self.local_config.items(): + result[LOCAL][op_name] = config.to_dict() result[GLOBAL] = global_config else: result = global_config @@ -99,10 +119,10 @@ def from_dict(cls, config_dict, str2operator=None): The constructed config. """ config = cls(**config_dict.get(GLOBAL, {})) - operator_config = config_dict.get(OPERATOR_NAME, {}) + operator_config = config_dict.get(LOCAL, {}) if operator_config: for op_name, op_config in operator_config.items(): - config.set_operator_name(op_name, cls(**op_config)) + config.set_local(op_name, cls(**op_config)) return config @classmethod @@ -120,7 +140,7 @@ def to_json_file(self, filename): config_dict = self.to_dict() with open(filename, "w", encoding="utf-8") as file: json.dump(config_dict, file, indent=4) - logger.info(f"Dump the config into {filename}") + logger.info("Dump the config into %s.", filename) def to_json_string(self, use_diff: bool = False) -> str: """Serializes this instance to a JSON string. 
@@ -137,7 +157,7 @@ def to_json_string(self, use_diff: bool = False) -> str: config_dict = self.to_diff_dict(self) else: config_dict = self.to_dict() - return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + return json.dumps(config_dict, indent=2) + "\n" def __repr__(self) -> str: return f"{self.__class__.__name__} {self.to_json_string()}" @@ -154,10 +174,82 @@ def validate(self, user_config: BaseConfig): pass def __add__(self, other: BaseConfig) -> BaseConfig: - # TODO(Yi) implement config add, like RTNWeightOnlyQuantConfig() + GPTQWeightOnlyQuantConfig() - pass + if isinstance(other, type(self)): + for op_name, config in other.local_config.items(): + self.set_local(op_name, config) + return self + else: + return ComposableConfig(configs=[self, other]) + + def _get_op_name_op_type_config(self): + op_type_config_dict = dict() + op_name_config_dict = dict() + for name, config in self.local_config.items(): + if self._is_op_type(name): + op_type_config_dict[name] = config + else: + op_name_config_dict[name] = config + return op_type_config_dict, op_name_config_dict + + def to_config_mapping( + self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None + ) -> OrderedDict[Union[str, Callable], OrderedDict[str, BaseConfig]]: + config_mapping = OrderedDict() + if config_list is None: + config_list = [self] + for config in config_list: + global_config = config.global_config + op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config() + for op_name, op_type in model_info: + config_mapping.setdefault(op_type, OrderedDict())[op_name] = global_config + if op_type in op_type_config_dict: + config_mapping[op_type][op_name] = op_name_config_dict[op_type] + if op_name in op_name_config_dict: + config_mapping[op_type][op_name] = op_name_config_dict[op_name] + return config_mapping @staticmethod def _is_op_type(name: str) -> bool: # TODO (Yi), ort and tf need override it return not isinstance(name, str) + + +class ComposableConfig(BaseConfig): + name = COMPOSABLE_CONFIG + + def __init__(self, configs: List[BaseConfig]) -> None: + self.config_list = configs + + def __add__(self, other: BaseConfig) -> BaseConfig: + if isinstance(other, type(self)): + self.config_list.extend(other.config_list) + else: + self.config_list.append(other) + return self + + def to_dict(self, params_list=[], operator2str=None): + result = {} + for config in self.config_list: + result[config.name] = config.to_dict() + return result + + @classmethod + def from_dict(cls, config_dict, str2operator=None): + # TODO(Yi) + pass + + def to_json_string(self, use_diff: bool = False) -> str: + return json.dumps(self.to_dict(), indent=2) + "\n" + + def __repr__(self) -> str: + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_config_mapping( + self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None + ) -> OrderedDict[str, BaseConfig]: + return super().to_config_mapping(self.config_list, model_info) + + @classmethod + def register_supported_configs(cls): + """Add all supported configs.""" + raise NotImplementedError diff --git a/neural_compressor/common/logger.py b/neural_compressor/common/logger.py new file mode 100644 index 00000000000..dd8cbd5388a --- /dev/null +++ b/neural_compressor/common/logger.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Logger: handles logging functionalities.""" + +import logging +import os + + +class Logger(object): + """Logger class.""" + + __instance = None + + def __new__(cls): + """Create a singleton Logger instance.""" + if Logger.__instance is None: + Logger.__instance = object.__new__(cls) + Logger.__instance._log() + return Logger.__instance + + def _log(self): + """Setup the logger format and handler.""" + LOGLEVEL = os.environ.get("LOGLEVEL", "INFO").upper() + self._logger = logging.getLogger("neural_compressor") + self._logger.handlers.clear() + self._logger.setLevel(LOGLEVEL) + formatter = logging.Formatter( + "%(asctime)s [%(levelname)s][%(filename)s:%(lineno)d] %(message)s", "%Y-%m-%d %H:%M:%S" + ) + streamHandler = logging.StreamHandler() + streamHandler.setFormatter(formatter) + self._logger.addHandler(streamHandler) + self._logger.propagate = False + + def get_logger(self): + """Get the logger.""" + return self._logger + + +def _pretty_dict(value, indent=0): + """Make the logger dict pretty.""" + prefix = "\n" + " " * (indent + 4) + if isinstance(value, dict): + items = [prefix + repr(key) + ": " + _pretty_dict(value[key], indent + 4) for key in value] + return "{%s}" % (",".join(items) + "\n" + " " * indent) + elif isinstance(value, list): + items = [prefix + _pretty_dict(item, indent + 4) for item in value] + return "[%s]" % (",".join(items) + "\n" + " " * indent) + elif isinstance(value, tuple): + items = [prefix + _pretty_dict(item, indent + 4) for item in value] + return "(%s)" % (",".join(items) + "\n" + " " * indent) + else: + return repr(value) + + +level = Logger().get_logger().level +DEBUG = logging.DEBUG + + +def log(level, msg, *args, **kwargs): + """Output log with the level as a parameter.""" + if isinstance(msg, dict): + for _, line in enumerate(_pretty_dict(msg).split("\n")): + Logger().get_logger().log(level, line, *args, **kwargs) + else: + Logger().get_logger().log(level, msg, *args, **kwargs) + + +def debug(msg, *args, **kwargs): + """Output log with the debug level.""" + if isinstance(msg, dict): + for _, line in enumerate(_pretty_dict(msg).split("\n")): + Logger().get_logger().debug(line, *args, **kwargs) + else: + Logger().get_logger().debug(msg, *args, **kwargs) + + +def error(msg, *args, **kwargs): + """Output log with the error level.""" + if isinstance(msg, dict): + for _, line in enumerate(_pretty_dict(msg).split("\n")): + Logger().get_logger().error(line, *args, **kwargs) + else: + Logger().get_logger().error(msg, *args, **kwargs) + + +def fatal(msg, *args, **kwargs): + """Output log with the fatal level.""" + if isinstance(msg, dict): + for _, line in enumerate(_pretty_dict(msg).split("\n")): + Logger().get_logger().fatal(line, *args, **kwargs) + else: + Logger().get_logger().fatal(msg, *args, **kwargs) + + +def info(msg, *args, **kwargs): + """Output log with the info level.""" + if isinstance(msg, dict): + for _, line in enumerate(_pretty_dict(msg).split("\n")): + Logger().get_logger().info(line, *args, **kwargs) + else: + Logger().get_logger().info(msg, *args, **kwargs) + + +def warn(msg, *args, **kwargs): + """Output log with the 
warning level.""" + if isinstance(msg, dict): + for _, line in enumerate(_pretty_dict(msg).split("\n")): + Logger().get_logger().warning(line, *args, **kwargs) + else: + Logger().get_logger().warning(msg, *args, **kwargs) + + +def warning(msg, *args, **kwargs): + """Output log with the warning level (Alias of the method warn).""" + if isinstance(msg, dict): + for _, line in enumerate(_pretty_dict(msg).split("\n")): + Logger().get_logger().warning(line, *args, **kwargs) + else: + Logger().get_logger().warning(msg, *args, **kwargs) diff --git a/neural_compressor/common/utility.py b/neural_compressor/common/utility.py index d1e05e18e53..51b37092033 100644 --- a/neural_compressor/common/utility.py +++ b/neural_compressor/common/utility.py @@ -20,8 +20,10 @@ # constants for configs GLOBAL = "global" -OPERATOR_NAME = "operator_name" +LOCAL = "local" # config name BASE_CONFIG = "base_config" +COMPOSABLE_CONFIG = "composable_config" RTN_WEIGHT_ONLY_QUANT = "rtn_weight_only_quant" +DUMMY_CONFIG = "dummy_config" diff --git a/neural_compressor/torch/__init__.py b/neural_compressor/torch/__init__.py index 6722679abbe..b8606e0b7f8 100644 --- a/neural_compressor/torch/__init__.py +++ b/neural_compressor/torch/__init__.py @@ -15,4 +15,10 @@ from neural_compressor.torch.utils import register_algo from neural_compressor.torch.algorithms import rtn_quantize_entry -from neural_compressor.torch.quantization import quantize, RTNWeightQuantConfig, get_default_rtn_config +from neural_compressor.torch.quantization import ( + quantize, + RTNWeightQuantConfig, + get_default_rtn_config, + DummyConfig, + get_default_dummy_config, +) diff --git a/neural_compressor/torch/algorithms/rtn.py b/neural_compressor/torch/algorithms/rtn.py new file mode 100644 index 00000000000..1c0071d99e9 --- /dev/null +++ b/neural_compressor/torch/algorithms/rtn.py @@ -0,0 +1,660 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 MIT HAN Lab +# This source code is licensed under the MIT license +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from torch.nn import functional as F + +from neural_compressor.common.logger import DEBUG, Logger, level +from neural_compressor.torch.utils import set_module + +logger = Logger().get_logger() + + +NF4 = [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, +] +FP4_BNB = [-12.0, -8.0, -6.0, -4.0, -3.0, -2.0, -0.0625, 0, 0.0625, 2.0, 3.0, 4.0, 6.0, 8.0, 12.0] +FP4_E2M1 = [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.0625, 0, 0.0625, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] + +# the order is the same as float list, bit value range is [-7, 7] +# 1111 = -1, 1110 = -2, 1101= -3, ... 
+ +NF4_BIT = [7, 1, 2, 3, 4, 5, 6, 0, -8, -7, -6, -5, -4, -3, -2, -1] +FP4_BNB_BIT = [-5, -6, -3, -4, -1, -2, -7, 0, 1, 6, 7, 4, 5, 2, 3] +FP4_E2M1_BIT = [-1, -2, -3, -4, -5, -6, -7, 0, 1, 2, 3, 4, 5, 6, 7] + +FLOAT_MAPPING = {"nf4": NF4, "fp4": FP4_BNB, "fp4_e2m1_bnb": FP4_BNB, "fp4_e2m1": FP4_E2M1} +INT_MAPPING = {"nf4": NF4_BIT, "fp4": FP4_BNB_BIT, "fp4_e2m1_bnb": FP4_BNB_BIT, "fp4_e2m1": FP4_E2M1_BIT} + + +def quantize_4bit(tensor, quantile=1.0, data_type="nf4", return_int=False): + """Quantize tensor to NF4/FP4 data type. + + Args: + tensor: input tensor + quantile (float, optional): percentile of clip. Defaults to 1.0. + data_type (str, optional): data type. Defaults to 'nf4'. + return_int (bool, optional): whether return int data. Defaults to False. + + Returns: + q_tensor: fake quantized tensor + """ + assert data_type in FLOAT_MAPPING, "unexpected data type." + allow_data = FLOAT_MAPPING[data_type] + allow_data_bit = INT_MAPPING[data_type] + # get scale and update tensor + scale = tensor.abs().max(1)[0] * quantile / max(allow_data) + scale.unsqueeze_(dim=-1) + tensor = tensor / scale + mid_data = [(allow_data[i] + allow_data[i + 1]) / 2 for i in range(len(allow_data) - 1)] + q_tensor = torch.zeros_like(tensor) + for i in range(len(allow_data)): + data = allow_data_bit[i] if return_int else allow_data[i] + if i == 0: + q_tensor += torch.where(tensor <= mid_data[i], data, 0) + elif i == len(allow_data) - 1: + q_tensor += torch.where(tensor > mid_data[i - 1], data, 0) + else: + q_tensor += torch.where((mid_data[i - 1] < tensor) & (tensor <= mid_data[i]), data, 0) + if return_int: + return q_tensor.type(torch.int8), scale.type(torch.float), None + return q_tensor * scale + + +def qdq_weight_asym(weight, num_bits=4, quantile=1.0, return_int=False): + """Quant and dequant tensor with asym schema. + + Args: + weight: input weight + num_bits (int, optional): num_bits. Defaults to 4. + quantile (float, optional): percentile of clip. Defaults to 1.0. + return_int (bool, optional): Choose return fp32 or int8/uint8 data. + Defaults to False. + + Returns: + output: qdq weight + """ + maxq = torch.tensor(2**num_bits - 1) + zeros = torch.zeros(weight.shape[0], device=weight.device) + wmin = torch.minimum(weight.min(1)[0], zeros) + wmax = torch.maximum(weight.max(1)[0], zeros) + wmin = wmin * quantile + wmax = wmax * quantile + tmp = (wmin == 0) & (wmax == 0) + wmin[tmp] = -1 + wmax[tmp] = +1 + scale = (wmax - wmin) / maxq + zp = torch.round(-wmin / scale) + scale.unsqueeze_(dim=-1) + zp.unsqueeze_(dim=-1) + q = torch.clamp(torch.round(weight / scale) + zp, 0, maxq) + if return_int: + return q.type(torch.uint8), scale.type(torch.float), zp.type(torch.uint8) + return scale * (q - zp) + + +def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False, full_range=False): + """Quant and dequant tensor with sym schema. + + Args: + weight : input weight + num_bits (int, optional): num_bits. Defaults to 4. + quantile (float, optional): percentile of clip. Defaults to 1.0. + return_int (bool, optional): Choose return fp32 or int8/uint8 data. + Defaults to False. + full_range (bool, optional): Choose sym range whether use -2**(bits-1). + For example: 4 bit + scale = amax / 8 if full_range else amax / 7 + If True, scale = -scale if abs(min)> abs(max) else scale + Defaults to False. 
+ + Returns: + output: qdq weight + """ + # assert num_bits > 1, "symmetric scheme only supports num_bits > 1" + maxq = torch.tensor(2 ** (num_bits - 1) - 1).to(weight.device) + minq = torch.tensor(-(2 ** (num_bits - 1))).to(weight.device) + if num_bits == 1: # pragma: no cover + maxq = torch.tensor(2 ** (num_bits - 1)) + minq = torch.tensor(2 ** (num_bits - 1) - 1) + max_val = torch.max(weight, 1)[0] + min_val = torch.min(weight, 1)[0] + flip_flag = torch.abs(max_val) > torch.abs(min_val) + wmax = torch.max(torch.abs(max_val), torch.abs(min_val)) + wmax = wmax * quantile + tmp = wmax == 0 + wmax[tmp] = +1 + if full_range: + # use -8, 8 to make sure amax is not changed after fake quant + scale = wmax / (-minq) + tmp = scale * flip_flag.int() + scale -= 2 * tmp # set negetive scale with flip_flag + else: + scale = wmax / maxq + scale.unsqueeze_(dim=-1) + q = torch.clamp(torch.round(weight / scale), minq, maxq) + if return_int: + return q.type(torch.int8), scale.type(torch.float), None + return scale * q + + +def qdq_weight_actor(weight, num_bits, scheme, quantile=1.0, data_type="int", return_int=False, full_range=False): + """Quant and dequant tensor per channel. + + Args: + weight : input weight + num_bits (int, optional): num_bits. Defaults to 4. + quantile (float, optional): percentile of clip. Defaults to 1.0. + data_type (str, optional): select from int, nf4, fp4. Defaults to int. + return_int (bool, optional): Choose return fp32 or int8/uint8 data. + Defaults to False. + full_range (bool, optional): Choose sym range whether use -2**(bits-1). + + Returns: + output: qdq weight + """ + assert num_bits > 0, "num_bits should be larger than 0" + if "int" not in data_type and num_bits == 4: + return quantize_4bit(weight, quantile=quantile, data_type=data_type, return_int=return_int) + if scheme == "sym": + return qdq_weight_sym(weight, num_bits, quantile, return_int, full_range) + else: + return qdq_weight_asym(weight, num_bits, quantile, return_int) + + +def quant_weight( + weight, num_bits=4, group_size=-1, scheme="asym", quantile=1.0, data_type="int", return_int=False, full_range=False +): + """Quant and dequant tensor with group size. + + Args: + weight: input weight + num_bits (int, optional): num_bits. Defaults to 4. + group_size (int, optional): how many elements share one scale/zp. Defaults to -1. + scheme (str, optional): sym or asym. Defaults to "asym". + quantile (float, optional): percentile of clip. Defaults to 1.0. + data_type (str, optional): select from int, nf4, fp4. Defaults to int. + return_int (bool, optional): Choose return fp32 or int8/uint8 data. + Defaults to False. + full_range (bool, optional): Choose sym range whether use -2**(bits-1). + + Returns: + output: qdq weight. 
+ """ + if num_bits <= 0: # pragma: no cover + return weight + if group_size == -1 or weight.shape[1] < group_size: + return qdq_weight_actor( + weight, + num_bits, + scheme=scheme, + quantile=quantile, + return_int=return_int, + full_range=full_range, + data_type=data_type, + ) + orig_shape = weight.shape + if weight.shape[1] % group_size == 0: + weight = weight.reshape(-1, group_size) + if return_int: + weight, scale, zp = qdq_weight_actor( + weight, + num_bits, + scheme=scheme, + quantile=quantile, + return_int=True, + full_range=full_range, + data_type=data_type, + ) + weight = weight.reshape(orig_shape) + scale = scale.reshape(orig_shape[0], -1) + if zp is not None: + zp = zp.reshape(orig_shape[0], -1) + return weight, scale, zp + else: + weight = qdq_weight_actor( + weight, num_bits, scheme=scheme, data_type=data_type, quantile=quantile, full_range=full_range + ) + return weight.reshape(orig_shape) + else: + split_index = weight.shape[1] // group_size * group_size + weight1 = weight[:, :split_index] + weight1 = weight1.reshape(-1, group_size) + if return_int: + weight1, scale1, zp1 = qdq_weight_actor( + weight1, + num_bits, + scheme=scheme, + data_type=data_type, + quantile=quantile, + return_int=True, + full_range=full_range, + ) + scale1 = scale1.reshape(orig_shape[0], -1) + if zp1 is not None: + zp1 = zp1.reshape(orig_shape[0], -1) + else: + weight1 = qdq_weight_actor( + weight1, num_bits, scheme=scheme, quantile=quantile, data_type=data_type, full_range=full_range + ) + weight1 = weight1.reshape(orig_shape[0], split_index) + weight2 = weight[:, split_index:] + if return_int: + weight2, scale2, zp2 = qdq_weight_actor( + weight2, + num_bits, + scheme=scheme, + data_type=data_type, + quantile=quantile, + return_int=True, + full_range=full_range, + ) + weight = torch.cat([weight1, weight2], dim=1) + scale = torch.cat([scale1, scale2], dim=1) + if zp2 is not None: + zp = torch.cat([zp1, zp2], dim=1) + else: + zp = None + return weight, scale, zp + else: + weight2 = qdq_weight_actor( + weight2, num_bits, scheme=scheme, data_type=data_type, quantile=quantile, full_range=full_range + ) + weight = torch.cat([weight1, weight2], dim=1) + return weight + + +def search_clip(m, num_bits=4, group_size=32, scheme="asym", data_type="int", enable_full_range=False): + """Search best clip range of each linears in current block. + + Args: + m (torch.nn.Module): torch module. + num_bits (int, optional): num bits. + group_size (int, optional): how many elements share one scale/zp. + scheme (str, optional): sym or asym. + data_type (str, optional): select from int, nf4, fp4. Defaults to int. + enable_full_range (bool, optional): Choose sym range whether use -2**(bits-1). 
+ + Returns: + best_clip_ratio (float): best percentile of clip + """ + org_weight = m.weight.data + logger.info("Searching the best clip range with RTN algorithm") + best_error = float("inf") + best_clip_ratio = None + n_grid = 200 + max_shrink = 0.2 + history = [] + for i_s in range(int(max_shrink * n_grid)): + ratio = 1 - i_s / n_grid # 1, 0.805-1.0 + cur_weight = quant_weight( + m.weight.data, + num_bits=num_bits, + group_size=group_size, + scheme=scheme, + data_type=data_type, + full_range=enable_full_range, + quantile=ratio, + ) + loss = (org_weight - cur_weight).float().pow(2).mean().item() + history.append(loss) + is_best = loss < best_error + if is_best: + best_error = loss + best_clip_ratio = ratio + logger.debug("The loss history of different clip range:{}".format(history)) + logger.debug("The best clip ratio is {}".format(best_clip_ratio)) + return best_clip_ratio + + +import math + + +class WeightOnlyLinear(torch.nn.Module): + def __init__( + self, + in_features, + out_features, + bits, + groupsize, + dtype="int", + zp=False, + bias=False, + scale_dtype=torch.float32, + compression_dtype=torch.int32, + compression_dim=1, + gptq_perm=False, + device="cpu", + ): + super().__init__() + self.dtype = dtype + if "int" not in self.dtype: # for nf4, fp4 + float_list = FLOAT_MAPPING[self.dtype] + int_list = INT_MAPPING[self.dtype] + self.int2float_mapping = {} + for k, v in zip(int_list, float_list): + self.int2float_mapping[k] = v + self.device = device + self.in_features = in_features + self.out_features = out_features + self.bits = bits + self.groupsize = groupsize if groupsize != -1 else in_features + self.compression_dim = compression_dim + assert compression_dtype in [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + ], "Only support torch.int8|16|32|64 as compressed dtype." + dtype_bits_mapping = {torch.int8: 8, torch.int16: 16, torch.int32: 32, torch.int64: 64} + self.compress_bits = dtype_bits_mapping[compression_dtype] + self.n_pack = self.compress_bits // self.bits + self.compressed_dtype = compression_dtype + self.float_type = scale_dtype + # K is input channel, N is output channel + assert compression_dim in [0, 1], ( + "Only support 0 or 1 as compression dimension, " + "0 is output channel, 1 is input channel." + ) + self.register_buffer( + "scale", + torch.zeros( + (out_features, math.ceil(in_features / self.groupsize)), + dtype=self.float_type, + ).to(device), + ) + if compression_dim == 1: + self.register_buffer( + "packed_weight", + torch.zeros( + (out_features, math.ceil(in_features / self.n_pack)), + dtype=self.compressed_dtype, + ).to(device), + ) + if zp: + self.register_buffer( + "packed_zp", + torch.zeros( + (self.out_features, math.ceil(self.in_features / self.groupsize / self.n_pack)), + dtype=self.compressed_dtype, + ).to(device), + ) + else: + self.register_buffer( + "packed_weight", + torch.zeros( + (math.ceil(out_features / self.n_pack), in_features), + dtype=self.compressed_dtype, + ).to(device), + ) + if zp: + self.register_buffer( + "packed_zp", + torch.zeros( + (math.ceil(self.out_features / self.n_pack), math.ceil(self.in_features / self.groupsize)), + dtype=self.compressed_dtype, + ).to(device), + ) + if bias: + self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device)) + else: + self.bias = None + + def pack(self, int_weight, scale, zp, bias, gptq_perm=None): + int_weight = int_weight.to(self.device) + if bias is not None: + assert hasattr(self, "bias"), "bias is not set when initializing." 
+ self.bias = bias.type(self.float_type).to(self.device) + assert ( + scale.shape == self.scale.shape + ), f"Scale shape is mismatched, got self.scale.shape: {self.scale.shape} and scale.shape: {scale.shape}" + self.scale = scale.type(self.float_type).to(self.device) + if self.compression_dim == 0: + int_weight = int_weight.T + self.packed_weight = self.packed_weight.T + origin_shape = int_weight.shape + target_shape = self.packed_weight.shape + assert origin_shape[0] == target_shape[0], "output channels mismatch, please check." + mask = torch.tensor(2**self.bits - 1, dtype=self.compressed_dtype).to(self.device) + + # pack weight + for j in range(target_shape[1]): + start = self.n_pack * j + end = self.n_pack * (j + 1) + tmp = int_weight[:, start:end].type(self.compressed_dtype) + for e in range(tmp.shape[1]): + tmp[:, e] &= mask + tmp[:, e] = tmp[:, e] << (self.bits * e) + self.packed_weight[:, j] |= tmp[:, e] + if self.compression_dim == 0: + self.packed_weight = self.packed_weight.T + + if zp is not None: + zp = zp.to(self.device) + if self.compression_dim == 0: + zp = zp.T + self.packed_zp = self.packed_zp.T + assert hasattr(self, "packed_zp"), "zp is not set when initializing." + target_shape = self.packed_zp.shape + for j in range(target_shape[1]): + start = self.n_pack * j + end = self.n_pack * (j + 1) + tmp = zp[:, start:end].type(self.compressed_dtype) + for e in range(tmp.shape[1]): + tmp[:, e] &= mask + tmp[:, e] = tmp[:, e] << (self.bits * e) + self.packed_zp[:, j] |= tmp[:, e] + if self.compression_dim == 0: + self.packed_zp = self.packed_zp.T + + def recover(self): + logger.debug(f"Recovering {self} weight") + device = self.scale.device + mask = torch.tensor(2**self.bits - 1, dtype=self.compressed_dtype).to(device) + weight_dtype = torch.int8 + # unpack weight + weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device) + packed_weight = self.packed_weight + if self.compression_dim == 0: + weight = weight.T + packed_weight = packed_weight.T + origin_shape = weight.shape + target_shape = packed_weight.shape + for j in range(target_shape[1]): + for e in range(self.n_pack): + index = j * self.n_pack + e + if index >= origin_shape[1]: + continue + tmp = packed_weight[:, j] + tmp = tmp << (self.compress_bits - self.bits * (e + 1)) + tmp = tmp >> self.compress_bits - self.bits + if weight_dtype == torch.uint8: + tmp &= mask # remove sign bit + weight[:, index] = tmp.type(weight_dtype) + if self.compression_dim == 0: + weight = weight.T + if "int" not in self.dtype: + new_weight = torch.zeros(self.out_features, self.in_features).to(device) + for k, v in self.int2float_mapping.items(): + new_weight += torch.where(weight == k, v, 0) + weight = new_weight + + # recover fp32 weight with int_weight, scale + left_element = self.in_features % self.groupsize + if left_element != 0: + split_index = self.in_features // self.groupsize * self.groupsize + weight1 = weight[:, :split_index].reshape(-1, self.groupsize) + scale1 = self.scale[:, :-1].reshape(-1, 1) + weight1 = (weight1 * scale1).reshape(self.out_features, -1) + weight2 = weight[:, split_index:] + scale2 = self.scale[:, -1:] + weight2 = weight2 * scale2 + fp32_weight = torch.cat((weight1, weight2), dim=1) + else: + weight = weight.reshape(-1, self.groupsize) + scale = self.scale.reshape(-1, 1) + fp32_weight = (weight * scale).reshape(self.out_features, -1) + return fp32_weight + + def forward(self, input): + weight = self.recover() + input = input.type(weight.dtype) + return F.linear(input, weight, 
self.bias)
+
+    def extra_repr(self) -> str:
+        return "in_features={}, out_features={}, bits={}, group_size={}, bias={}".format(
+            self.in_features, self.out_features, self.bits, self.groupsize, self.bias is not None
+        )
+
+
+def rtn_quantize(
+    model,
+    num_bits=4,
+    group_size=32,
+    scheme="asym",
+    quantile=1.0,
+    weight_config={},
+    return_int=False,
+    data_type="int",
+    enable_full_range=False,
+    enable_mse_search=False,
+    group_dim=1,
+    **kwargs,
+):
+    """Quantize the model with the round-to-nearest (RTN) method.
+
+    Args:
+        model: torch module
+        num_bits: num bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+        scheme (str, optional): sym or asym. Defaults to "asym".
+        quantile (float, optional): percentile of clip. Defaults to 1.0.
+        data_type (str, optional): select from int, nf4, fp4. Defaults to int.
+        weight_config (dict, optional): specific layer wise configurations. Defaults to {}.
+            For example,
+                weight_config={
+                    'fc2':
+                        {
+                            'bits': 4,
+                            'group_size': 32,
+                            'scheme': 'sym',
+                            'gptq_perm': [1, 1, ...]  # for gptq perm
+                        }
+                }
+        return_int (bool, optional): Choose return fp32 or int32 model.
+                                     Defaults to False.
+        enable_full_range (bool, optional): Choose sym range whether use -2**(bits-1).
+                                            Defaults to False.
+        enable_mse_search (bool, optional): Whether to search the best clip range.
+                                            Defaults to False.
+        group_dim (int, optional): 0 means splitting output channel,
+                                   1 means splitting input channel. Defaults to 1.
+
+    Returns:
+        model: fake quantized torch module
+    """
+    assert isinstance(model, torch.nn.Module), "only support torch module"
+    supported_layers = ["Linear"]
+    if return_int:
+        compression_dtype = kwargs.get("compression_dtype", torch.int32)
+        compression_dim = kwargs.get("compression_dim", 1)
+        scale_dtype = kwargs.get("scale_dtype", torch.float32)
+        device = kwargs.get("device", "cpu")
+    for name, m in model.named_modules():
+        if m.__class__.__name__ not in supported_layers:
+            continue
+        if name in weight_config:  # pragma: no cover
+            num_bits = weight_config[name]["bits"]
+            group_size = weight_config[name]["group_size"]
+            scheme = weight_config[name]["scheme"]
+            quantile = weight_config[name].get("quantile", 1.0)
+        logger.debug(f"RTN quantized module:{name, m}")
+        log_msg = (
+            f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, "
+            + f"scheme={scheme}, quantile={quantile}"
+        )
+        if data_type != "int":
+            log_msg += f", dtype={data_type}"
+        elif scheme == "sym":  # nf4/fp4 is always [-7,7]
+            log_msg += f", enable_full_range={enable_full_range}"
+        logger.debug(log_msg)
+        weight = m.weight.T if group_dim == 0 else m.weight
+        if enable_mse_search:
+            quantile = search_clip(m, num_bits, group_size, scheme, data_type, enable_full_range)
+        if return_int:
+            int_weight, scale, zp = quant_weight(
+                weight,
+                num_bits,
+                group_size,
+                scheme,
+                quantile,
+                data_type=data_type,
+                return_int=True,
+                full_range=enable_full_range,
+            )
+            int_weight = int_weight.T if group_dim == 0 else int_weight
+            scale = scale.T if group_dim == 0 else scale
+            zp = zp.T if group_dim == 0 and zp is not None else zp
+            new_module = WeightOnlyLinear(
+                m.in_features,
+                m.out_features,
+                num_bits,
+                group_size,
+                dtype=data_type,
+                zp=zp is not None,
+                bias=m.bias is not None,
+                compression_dtype=compression_dtype,
+                compression_dim=compression_dim,
+                scale_dtype=scale_dtype,
+                device=device,
+            )
+            new_module.pack(int_weight, scale, zp, m.bias)
+            if name == "":
+                return new_module
+            else:
+                set_module(model, name, new_module)
+        else:
+            q_weight =
quant_weight( + weight, + num_bits, + group_size, + scheme, + quantile, + data_type=data_type, + full_range=enable_full_range, + ) + q_weight = q_weight.T if group_dim == 0 else q_weight + m.weight.data.copy_(q_weight) + return model diff --git a/neural_compressor/torch/algorithms/rtn_quantize.py b/neural_compressor/torch/algorithms/rtn_quantize.py index 6f2592b9824..55e9fd31f4d 100644 --- a/neural_compressor/torch/algorithms/rtn_quantize.py +++ b/neural_compressor/torch/algorithms/rtn_quantize.py @@ -17,15 +17,14 @@ import torch -from neural_compressor.adaptor.torch_utils.util import fetch_module, set_module - -# TODO(Yi) move the algorithm implementations from adaptor.torch_utils to neural_compressor.torch.algo -from neural_compressor.adaptor.torch_utils.weight_only import rtn_quantize as torch_rtn_quantize from neural_compressor.common.base_config import BaseConfig +from neural_compressor.common.logger import Logger from neural_compressor.common.utility import RTN_WEIGHT_ONLY_QUANT +from neural_compressor.torch.algorithms.rtn import rtn_quantize as torch_rtn_quantize from neural_compressor.torch.quantization.config import RTNWeightQuantConfig -from neural_compressor.torch.utils import register_algo -from neural_compressor.utils import logger +from neural_compressor.torch.utils import fetch_module, register_algo, set_module + +logger = Logger().get_logger() def _apply_rtn_on_single_module(module: torch.nn.Module, quant_config: RTNWeightQuantConfig) -> torch.nn.Module: @@ -34,14 +33,15 @@ def _apply_rtn_on_single_module(module: torch.nn.Module, quant_config: RTNWeight group_dim = quant_config.group_dim dtype = quant_config.weight_dtype num_bits = quant_config.weight_bits - scheme = quant_config.weight_sym + scheme = "sym" if quant_config.weight_sym else "asym" group_size = quant_config.weight_group_size + return_int = quant_config.return_int return torch_rtn_quantize( module, num_bits, group_size, scheme, - return_int=False, + return_int=return_int, data_type=dtype, enable_full_range=enable_full_range, enable_mse_search=enable_mse_search, diff --git a/neural_compressor/torch/quantization/__init__.py b/neural_compressor/torch/quantization/__init__.py index fc7d7452422..e159bf99bad 100644 --- a/neural_compressor/torch/quantization/__init__.py +++ b/neural_compressor/torch/quantization/__init__.py @@ -13,4 +13,9 @@ # limitations under the License. 
 from neural_compressor.torch.quantization.quantize import quantize
-from neural_compressor.torch.quantization.config import RTNWeightQuantConfig, get_default_rtn_config
+from neural_compressor.torch.quantization.config import (
+    RTNWeightQuantConfig,
+    get_default_rtn_config,
+    DummyConfig,
+    get_default_dummy_config,
+)
diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py
index 177ce78853e..16de62fab36 100644
--- a/neural_compressor/torch/quantization/config.py
+++ b/neural_compressor/torch/quantization/config.py
@@ -23,7 +23,7 @@
 import torch
 
 from neural_compressor.common.base_config import BaseConfig, register_config, registered_configs
-from neural_compressor.common.utility import RTN_WEIGHT_ONLY_QUANT
+from neural_compressor.common.utility import DUMMY_CONFIG, RTN_WEIGHT_ONLY_QUANT
 
 FRAMEWORK_NAME = "torch"
@@ -61,6 +61,7 @@ class RTNWeightQuantConfig(BaseConfig):
         "enable_full_range",
         "enable_mse_search",
         "group_dim",
+        "return_int",
     ]
     name = RTN_WEIGHT_ONLY_QUANT
@@ -74,6 +75,7 @@ def __init__(
         enable_full_range: bool = False,
         enable_mse_search: bool = False,
         group_dim: int = 1,
+        return_int: bool = False,
     ):
         """Init RTN weight-only quantization config.
@@ -86,6 +88,7 @@ def __init__(
             enable_full_range (bool): Enables full range for activations, default is False.
             enable_mse_search (bool): Enables mean squared error (MSE) search, default is False.
             group_dim (int): Dimension for grouping, default is 1.
+            return_int (bool): Whether to return the model with int8/uint8 weights instead of fp32. Defaults to False.
         """
         super().__init__()
         self.weight_bits = weight_bits
@@ -96,6 +99,7 @@ def __init__(
         self.enable_full_range = enable_full_range
         self.enable_mse_search = enable_mse_search
         self.group_dim = group_dim
+        self.return_int = return_int
 
     def to_dict(self):
         return super().to_dict(params_list=self.params_list, operator2str=operator2str)
@@ -113,6 +117,9 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
             weight_group_size=[32, -1, 1, 4, 8, 16, 64, 128, 256, 512, 1024],
             weight_sym=[True, False],
             act_dtype=["fp32"],
+            enable_full_range=[False, True],
+            enable_mse_search=[False, True],
+            group_dim=[1, 0],
         )
         operators = [torch.nn.Linear, torch.nn.functional.linear]
         supported_configs.append(OperatorConfig(config=linear_rtn_config, operators=operators, backend=Backend.DEFAULT))
@@ -123,6 +130,75 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
 RTNWeightQuantConfig.register_supported_configs()
 
 
+def get_default_rtn_config() -> RTNWeightQuantConfig:
+    """Generate the default rtn config.
+
+    Returns:
+        the default rtn config.
+    """
+    return RTNWeightQuantConfig()
+
+
+@register_config(framework_name=FRAMEWORK_NAME, algo_name=DUMMY_CONFIG)
+class DummyConfig(BaseConfig):
+    """Config class for the dummy quantization algorithm, used to exercise config composition."""
+
+    supported_configs: List[OperatorConfig] = []
+    params_list = ["act_dtype", "weight_dtype", "dummy_attr"]
+    name = DUMMY_CONFIG
+
+    def __init__(
+        self,
+        weight_dtype: str = "int",
+        act_dtype: str = "fp32",
+        dummy_attr: int = 0,
+    ):
+        """Init the dummy quantization config.
+
+        Args:
+            act_dtype (str): Data type for activations, default is "fp32".
+            weight_dtype (str): Data type for weights, default is "int".
+            dummy_attr (int): Dummy attribute, default is 0.
+ """ + super().__init__() + self.act_dtype = act_dtype + self.weight_dtype = weight_dtype + self.dummy_attr = dummy_attr + + def to_dict(self): + return super().to_dict(params_list=self.params_list, operator2str=operator2str) + + @classmethod + def from_dict(cls, config_dict): + return super(DummyConfig, cls).from_dict(config_dict=config_dict, str2operator=str2operator) + + @classmethod + def register_supported_configs(cls) -> List[OperatorConfig]: + supported_configs = [] + linear_dummy_config = DummyConfig( + act_dtype=["fp32"], + weight_dtype=["int4", "int8"], + dummy_attr=[1, 2, 3], + ) + operators = [torch.nn.Linear, torch.nn.functional.linear] + supported_configs.append( + OperatorConfig(config=linear_dummy_config, operators=operators, backend=Backend.DEFAULT) + ) + cls.supported_configs = supported_configs + + +def get_default_dummy_config() -> DummyConfig: + """Generate the default dummy config. + + Returns: + the default dummy config. + """ + return DummyConfig() + + +##################### Algo Configs End ################################### + + def get_all_registered_configs() -> Dict[str, BaseConfig]: return registered_configs.get(FRAMEWORK_NAME, {}) @@ -134,12 +210,3 @@ def parse_config_from_dict(config_dict: Dict) -> BaseConfig: config = torch_registered_configs[key].from_dict(val) return config # TODO(Yi) parse multiple configs after support configs add - - -def get_default_rtn_config() -> RTNWeightQuantConfig: - """Generate the default rtn config. - - Returns: - the default rtn config. - """ - return RTNWeightQuantConfig() diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py index ff65e85b0fa..e53023ac363 100644 --- a/neural_compressor/torch/quantization/quantize.py +++ b/neural_compressor/torch/quantization/quantize.py @@ -17,12 +17,12 @@ import torch from neural_compressor.common.base_config import BaseConfig +from neural_compressor.common.logger import Logger from neural_compressor.common.utility import RTN_WEIGHT_ONLY_QUANT from neural_compressor.torch.quantization.config import parse_config_from_dict from neural_compressor.torch.utils import algos_mapping -# TODO (Yi) move logger into common in next PR -from neural_compressor.utils import logger +logger = Logger().get_logger() def quantize( diff --git a/neural_compressor/torch/utils.py b/neural_compressor/torch/utils.py index a32ae94eaf3..134bb14797c 100644 --- a/neural_compressor/torch/utils.py +++ b/neural_compressor/torch/utils.py @@ -13,11 +13,17 @@ # limitations under the License. -from typing import Callable, Dict +from typing import Callable, Dict, List, Tuple + +from neural_compressor.common.logger import Logger + +logger = Logger().get_logger() # Dictionary to store a mapping between algorithm names and corresponding algo implementation(function) algos_mapping: Dict[str, Callable] = {} +import torch + def register_algo(name): """Decorator function to register algorithms in the algos_mapping dictionary. @@ -39,3 +45,61 @@ def decorator(algo_func): return algo_func return decorator + + +def fetch_module(model, op_name): + """Get module with a given op name. + + Args: + model (object): the input model. + op_name (str): name of op. + + Returns: + module (object). 
+ """ + module = model + name_list = op_name.split(".") + for name in name_list: + if hasattr(module, name): + module = getattr(module, name) + else: + logger.warning(f"The {op_name} is not present in the model.") + return None + return module + + +def set_module(model, op_name, new_module): + """Set module with a given op name. + + Args: + model (object): the input model. + op_name (str): name of op. + new_module (object): the input model. + + Returns: + module (object). + """ + name_list = op_name.split(".") + if len(name_list) == 1: + setattr(model, name_list[-1], new_module) + return + else: + second_last_module = fetch_module(model, ".".join(name_list[:-1])) + if second_last_module is None: + logger.warning(f"Setting skipped as the {op_name} is not present in the model.") + return None + else: + setattr(second_last_module, name_list[-1], new_module) + + +def get_model_info(model: torch.nn.Module, white_module_list: List[Callable]) -> List[Tuple[str, Callable]]: + module_dict = dict(model.named_modules()) + filter_result = [] + filter_result_set = set() + for op_name, module in module_dict.items(): + if isinstance(module, tuple(white_module_list)): + pair = (op_name, type(module)) + if pair not in filter_result_set: + filter_result_set.add(pair) + filter_result.append(pair) + return filter_result diff --git a/test/3x/torch/test_config.py b/test/3x/torch/test_config.py index 9ea17a4d0af..e366873eaea 100644 --- a/test/3x/torch/test_config.py +++ b/test/3x/torch/test_config.py @@ -1,6 +1,6 @@ import unittest -from neural_compressor.utils.logger import Logger +from neural_compressor.common.logger import Logger logger = Logger().get_logger() import torch @@ -35,7 +35,7 @@ def tearDownClass(self): def setUp(self): # print the test name - logger.info("Running TestQuantizationConfig test: %s".format()) + logger.info(f"Running TestQuantizationConfig test: {self.id()}") def test_quantize_rtn_from_dict_default(self): logger.info("test_quantize_rtn_from_dict_default") @@ -78,7 +78,7 @@ def test_quantize_rtn_from_dict_advance(self): "weight_bits": 4, "weight_group_size": 32, }, - "operator_name": { + "local": { "fc1": { "weight_dtype": "int8", "weight_bits": 4, @@ -93,12 +93,9 @@ def test_quantize_rtn_from_class_advance(self): from neural_compressor.torch import RTNWeightQuantConfig, quantize quant_config = RTNWeightQuantConfig(weight_bits=4, weight_dtype="nf4") - # # set operator type - # linear_config = RTNWeightQuantConfig(weight_bits=6, weight_dtype="nf4") - # quant_config._set_operator_type(torch.nn.Linear, linear_config) # set operator instance fc1_config = RTNWeightQuantConfig(weight_bits=4, weight_dtype="int8") - quant_config.set_operator_name("model.fc1", fc1_config) + quant_config.set_local("model.fc1", fc1_config) # get model and quantize fp32_model = build_simple_torch_model() qmodel = quantize(fp32_model, quant_config) @@ -114,7 +111,7 @@ def test_config_from_dict(self): "weight_bits": 4, "weight_group_size": 32, }, - "operator_name": { + "local": { "fc1": { "weight_dtype": "int8", "weight_bits": 4, @@ -123,17 +120,107 @@ def test_config_from_dict(self): } } config = RTNWeightQuantConfig.from_dict(quant_config) - self.assertIsNotNone(config.operator_name_config) + self.assertIsNotNone(config.local_config) def test_config_to_dict(self): from neural_compressor.torch import RTNWeightQuantConfig quant_config = RTNWeightQuantConfig(weight_bits=4, weight_dtype="nf4") fc1_config = RTNWeightQuantConfig(weight_bits=4, weight_dtype="int8") - quant_config.set_operator_name("model.fc1", 
fc1_config) + quant_config.set_local("model.fc1", fc1_config) config_dict = quant_config.to_dict() self.assertIn("global", config_dict) - self.assertIn("operator_name", config_dict) + self.assertIn("local", config_dict) + + def test_same_type_configs_addition(self): + from neural_compressor.torch import RTNWeightQuantConfig + + quant_config1 = { + "rtn_weight_only_quant": { + "weight_dtype": "nf4", + "weight_bits": 4, + "weight_group_size": 32, + }, + } + q_config = RTNWeightQuantConfig.from_dict(quant_config1["rtn_weight_only_quant"]) + quant_config2 = { + "rtn_weight_only_quant": { + "global": { + "weight_bits": 8, + "weight_group_size": 32, + }, + "local": { + "fc1": { + "weight_dtype": "int8", + "weight_bits": 4, + } + }, + } + } + q_config2 = RTNWeightQuantConfig.from_dict(quant_config2["rtn_weight_only_quant"]) + q_config3 = q_config + q_config2 + q3_dict = q_config3.to_dict() + for op_name, op_config in quant_config2["rtn_weight_only_quant"]["local"].items(): + for attr, val in op_config.items(): + self.assertEqual(q3_dict["local"][op_name][attr], val) + self.assertNotEqual( + q3_dict["global"]["weight_bits"], quant_config2["rtn_weight_only_quant"]["global"]["weight_bits"] + ) + + def test_diff_types_configs_addition(self): + from neural_compressor.torch import DummyConfig, RTNWeightQuantConfig + + quant_config1 = { + "rtn_weight_only_quant": { + "weight_dtype": "nf4", + "weight_bits": 4, + "weight_group_size": 32, + }, + } + q_config = RTNWeightQuantConfig.from_dict(quant_config1["rtn_weight_only_quant"]) + d_config = DummyConfig(act_dtype="fp32", dummy_attr=3) + combined_config = q_config + d_config + combined_config_d = combined_config.to_dict() + logger.info(combined_config) + self.assertTrue("rtn_weight_only_quant" in combined_config_d) + self.assertIn("dummy_config", combined_config_d) + + def test_composable_config_addition(self): + from neural_compressor.torch import DummyConfig, RTNWeightQuantConfig + + quant_config1 = { + "rtn_weight_only_quant": { + "weight_dtype": "nf4", + "weight_bits": 4, + "weight_group_size": 32, + }, + } + q_config = RTNWeightQuantConfig.from_dict(quant_config1["rtn_weight_only_quant"]) + d_config = DummyConfig(act_dtype="fp32", dummy_attr=3) + combined_config = q_config + d_config + combined_config_d = combined_config.to_dict() + logger.info(combined_config) + self.assertTrue("rtn_weight_only_quant" in combined_config_d) + self.assertIn("dummy_config", combined_config_d) + combined_config2 = combined_config + d_config + combined_config3 = combined_config + combined_config2 + + def test_config_mapping(self): + from neural_compressor.torch import RTNWeightQuantConfig + from neural_compressor.torch.utils import get_model_info + + quant_config = RTNWeightQuantConfig(weight_bits=4, weight_dtype="nf4") + # set operator instance + fc1_config = RTNWeightQuantConfig(weight_bits=6, weight_dtype="int8") + quant_config.set_local("fc1", fc1_config) + # get model and quantize + fp32_model = build_simple_torch_model() + model_info = get_model_info(fp32_model, white_module_list=[torch.nn.Linear]) + logger.info(quant_config) + configs_mapping = quant_config.to_config_mapping(model_info=model_info) + logger.info(configs_mapping) + self.assertTrue(configs_mapping[torch.nn.Linear]["fc1"].weight_bits == 6) + self.assertTrue(configs_mapping[torch.nn.Linear]["fc2"].weight_bits == 4) if __name__ == "__main__": diff --git a/test/3x/torch/test_logger.py b/test/3x/torch/test_logger.py new file mode 100644 index 00000000000..c091eba8db4 --- /dev/null +++ 
b/test/3x/torch/test_logger.py @@ -0,0 +1,34 @@ +"""Tests for logging utilities.""" +import unittest + +from neural_compressor.common import logger + + +class TestLogger(unittest.TestCase): + def test_logger(self): + logger.log(0, "call logger log function.") + logger.log(1, {"msg": "call logger log function."}) + logger.debug("call logger debug function.") + logger.debug({"msg": "call logger debug function."}) + logger.error("call logger error function.") + logger.error({"msg": "call logger error function."}) + logger.fatal("call logger fatal function") + logger.fatal({"msg": "call logger fatal function"}) + logger.info("call logger info function") + logger.info({"msg": "call logger info function."}) + logger.warn("call logger warn function") + logger.warn({"msg": "call logger warn function"}) + logger.warning("call logger warning function") + logger.warning({"msg": "call logger warning function"}) + logger.warning(["call logger warning function", "done"]) + logger.warning(("call logger warning function", "done")) + logger.warning({"msg": {("bert", "embedding"): {"weight": {"dtype": ["unint8", "int8"]}}}}) + logger.warning({"msg": {("bert", "embedding"): {"op": ("a", "b")}}}) + # the following log will not be prettified + logger.warning([{"msg": "call logger warning function"}, {"msg2": "done"}]) + logger.warning(({"msg": "call logger warning function"}, {"msg2": "done"})) + logger.warning(({"msg": [{"sub_msg": "call logger"}, {"sub_msg2": "call warning function"}]}, {"msg2": "done"})) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/3x/torch/test_rtn.py b/test/3x/torch/test_rtn.py new file mode 100644 index 00000000000..879ecc1b13e --- /dev/null +++ b/test/3x/torch/test_rtn.py @@ -0,0 +1,136 @@ +import unittest + +import torch + +from neural_compressor.common.logger import Logger + +logger = Logger().get_logger() + + +def get_gpt_j(): + import transformers + + tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-GPTJForCausalLM", + torchscript=True, + ) + return tiny_gptj + + +def build_simple_torch_model(): + class Model(torch.nn.Module): + def __init__(self): + super(Model, self).__init__() + self.fc1 = torch.nn.Linear(8, 30) + self.fc2 = torch.nn.Linear(30, 60) + self.fc3 = torch.nn.Linear(60, 30) + self.fc4 = torch.nn.Linear(30, 50) + + def forward(self, x): + out = self.fc1(x) + out = self.fc2(out) + out = self.fc3(out) + out = self.fc4(out) + return out + + model = Model() + return model + + +class TestRTNQuant(unittest.TestCase): + @classmethod + def setUpClass(self): + pass + + @classmethod + def tearDownClass(self): + pass + + def setUp(self): + # print the test name + logger.info(f"Running TestRTNQuant test: {self.id()}") + + def _apply_rtn(self, quant_config): + logger.info(f"Test RTN with config {quant_config}") + from neural_compressor.torch import quantize + + fp32_model = build_simple_torch_model() + qmodel = quantize(fp32_model, quant_config) + self.assertIsNotNone(qmodel) + return qmodel + + def test_rtn(self): + from neural_compressor.torch import RTNWeightQuantConfig + + # some tests were skipped to accelerate the CI + rnt_options = { + "weight_dtype": ["int", "int8", "nf4", "fp4_e2m1_bnb"], + "weight_bits": [4, 1, 8], + "weight_group_size": [32, -1, 1, 512], + "weight_sym": [True, False], + "act_dtype": ["fp32"], + "enable_full_range": [False, True], + "enable_mse_search": [False], + "group_dim": [1, 0], + "return_int": [False, True], + } + from itertools import product + + keys = 
RTNWeightQuantConfig.params_list + for value in product(*rnt_options.values()): + d = dict(zip(keys, value)) + if (d["weight_dtype"] == "int" and d["weight_bits"] != 8) or ( + d["enable_full_range"] + and d["enable_mse_search"] + or (d["return_int"] and (d["group_dim"] != 1 or d["weight_bits"] != 8)) + ): + continue + quant_config = RTNWeightQuantConfig(**d) + self._apply_rtn(quant_config) + + def test_rtn_return_type(self): + from neural_compressor.torch import RTNWeightQuantConfig + + for return_int in [True, False]: + quant_config = RTNWeightQuantConfig(return_int=return_int) + qmodel = self._apply_rtn(quant_config) + + def test_rtn_mse_search(self): + from neural_compressor.torch import RTNWeightQuantConfig + + quant_config = RTNWeightQuantConfig(enable_mse_search=True) + qmodel = self._apply_rtn(quant_config) + + def test_rtn_recover(self): + from neural_compressor.torch import RTNWeightQuantConfig + + quant_config = RTNWeightQuantConfig(return_int=True) + qmodel = self._apply_rtn(quant_config) + input = torch.randn(4, 8) + # test forward + out = qmodel(input) + recovered_fc1 = qmodel.fc1.recover() + self.assertIsNotNone(recovered_fc1) + + def test_weight_only_linear(self): + from neural_compressor.torch.algorithms.rtn import rtn_quantize + + model = build_simple_torch_model() + options = { + "compression_dtype": [torch.int8, torch.int16, torch.int32, torch.int64], + "compression_dim": [0, 1], + "module": [model.fc1, model.fc2, model.fc3, model.fc4], + } + from itertools import product + + for compression_dtype, compression_dim, module in product(*options.values()): + q_model = rtn_quantize( + model=module, + return_int=True, + compression_dtype=compression_dtype, + compression_dim=compression_dim, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/3x/torch/test_utils.py b/test/3x/torch/test_utils.py new file mode 100644 index 00000000000..c5e33bab611 --- /dev/null +++ b/test/3x/torch/test_utils.py @@ -0,0 +1,88 @@ +import unittest + +import torch + +from neural_compressor.common.logger import Logger + +logger = Logger().get_logger() + + +def get_gpt_j(): + import transformers + + tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-GPTJForCausalLM", + torchscript=True, + ) + return tiny_gptj + + +def build_simple_torch_model(): + class Model(torch.nn.Module): + def __init__(self): + super(Model, self).__init__() + self.fc1 = torch.nn.Linear(8, 30) + self.fc2 = torch.nn.Linear(30, 60) + self.fc3 = torch.nn.Linear(60, 30) + self.fc4 = torch.nn.Linear(30, 50) + + def forward(self, x): + out = self.fc1(x) + out = self.fc2(out) + out = self.fc3(out) + out = self.fc4(out) + return out + + model = Model() + return model + + +from neural_compressor.torch.utils import fetch_module, set_module + + +class TestTorchUtils(unittest.TestCase): + @classmethod + def setUpClass(self): + self.model = get_gpt_j() + + @classmethod + def tearDownClass(self): + pass + + def setUp(self): + # print the test name + logger.info(f"Running TestTorchUtils test: {self.id()}") + + def test_fetch_module(self): + result = fetch_module(self.model, "transformer.h.2.mlp.fc_in") + self.assertIsInstance(result, torch.nn.Linear) + + def test_set_module(self): + module_name = "transformer.h.2.mlp.fc_in" + mew_value = torch.nn.Linear(32, 128, bias=False) + set_module(self.model, module_name, mew_value) + result = fetch_module(self.model, module_name) + self.assertFalse(result.bias) + + def test_set_module_nonexistent_attribute(self): + new_value = 
torch.nn.Parameter(torch.Tensor([3.0])) + attr_name = "transformer.nonexistent_attr" + set_module(self.model, attr_name, new_value) + result = fetch_module(self.model, attr_name) + self.assertTrue(torch.equal(result, torch.Tensor([3.0]))) + + def test_fetch_module_nonexistent_attribute(self): + attr_name = "transformer.nonexistent_attr" + result = fetch_module(self.model, attr_name) + self.assertIsNone(result) + + def test_get_model_info(self): + from neural_compressor.torch.utils import get_model_info + + white_module_list = [torch.nn.Linear] + model_info = get_model_info(build_simple_torch_model(), white_module_list) + self.assertEqual(len(model_info), 4) + + +if __name__ == "__main__": + unittest.main()
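
A minimal end-to-end usage sketch of the 3.x API touched by this patch: the global/local config split, config addition, config mapping, and quantize. The toy module, layer names, and chosen bit widths below are illustrative only; the imported names and methods are the ones added or updated in this diff.

import torch

from neural_compressor.torch import DummyConfig, RTNWeightQuantConfig, quantize
from neural_compressor.torch.utils import get_model_info


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 16)
        self.fc2 = torch.nn.Linear(16, 8)

    def forward(self, x):
        return self.fc2(self.fc1(x))


model = ToyModel()

# Global 4-bit NF4 RTN config, plus a per-operator override for fc1 via set_local().
quant_config = RTNWeightQuantConfig(weight_bits=4, weight_dtype="nf4")
quant_config.set_local("fc1", RTNWeightQuantConfig(weight_bits=8, weight_dtype="int8"))

# Resolve the final per-operator configs: an OrderedDict keyed by op type, then op name.
model_info = get_model_info(model, white_module_list=[torch.nn.Linear])
config_mapping = quant_config.to_config_mapping(model_info=model_info)
assert config_mapping[torch.nn.Linear]["fc1"].weight_bits == 8
assert config_mapping[torch.nn.Linear]["fc2"].weight_bits == 4

# Configs of different algorithms can be added together into a ComposableConfig.
composed = quant_config + DummyConfig(act_dtype="fp32", dummy_attr=3)
print(composed.to_dict())

# Apply RTN weight-only (fake) quantization.
q_model = quantize(model, quant_config)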
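The new neural_compressor/torch/algorithms/rtn.py module can also be driven directly. A sketch of the return_int path on a standalone layer, assuming a symmetric scheme so that no zero point needs to be packed:

import torch

from neural_compressor.torch.algorithms.rtn import rtn_quantize

fc = torch.nn.Linear(64, 32)

# With return_int=True a Linear layer is replaced by (here: returned as) a
# WeightOnlyLinear module that stores packed int weights and per-group scales.
packed = rtn_quantize(fc, num_bits=4, group_size=32, scheme="sym", return_int=True)

recovered = packed.recover()        # rebuild the fp32 weight from the packed buffers
out = packed(torch.randn(2, 64))    # forward() recovers the weight and calls F.linear
print(packed.extra_repr(), recovered.shape, out.shape)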
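The neural_compressor.common logging helpers added here are plain module-level functions, and dict/list/tuple messages are pretty-printed line by line. A small sketch; the message contents are arbitrary:

from neural_compressor.common import logger

# The LOGLEVEL environment variable controls verbosity (default INFO).
logger.info("RTN quantization started.")
logger.warning({("fc1", "Linear"): {"weight": {"dtype": ["int8", "nf4"]}}})
logger.debug("only visible when LOGLEVEL=DEBUG is exported")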