diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py
index 8235594a8d2..9d81fea6c08 100644
--- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py
+++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py
@@ -82,7 +82,7 @@
 parser.add_argument("--double_quant_type",
                     type=str,
                     default=None,
-                    choices=['GGML_TYPE_Q4_K', 'BNB'],
+                    choices=['GGML_TYPE_Q4_K', 'BNB_NF4'],
                     help="DoubleQuant parameter")
 parser.add_argument("--double_quant_dtype",
                     type=str,
@@ -230,8 +230,8 @@ def get_user_model():
 
 # 3.x api
 if args.approach == 'weight_only':
-    from neural_compressor.torch import RTNConfig, GPTQConfig, quantize
-    from neural_compressor.torch.utils.utility import get_double_quant_config
+    from neural_compressor.torch.quantization import RTNConfig, GPTQConfig, quantize
+    from neural_compressor.torch.utils import get_double_quant_config
     weight_sym = True if args.woq_scheme == "sym" else False
     double_quant_config_dict = get_double_quant_config(args.double_quant_type, weight_sym=weight_sym)
 
diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_quant.sh b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_quant.sh
index 625362afa50..751d50512db 100644
--- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_quant.sh
+++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_quant.sh
@@ -50,7 +50,7 @@ function run_tuning {
        model_name_or_path="facebook/opt-125m"
        approach="weight_only"
        extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_enable_mse_search --gptq_pad_max_length 2048 --gptq_use_max_length"
-       extra_cmd=$extra_cmd" --double_quant_type BNB"
+       extra_cmd=$extra_cmd" --double_quant_type BNB_NF4"
    elif [ "${topology}" = "opt_125m_woq_gptq_int4_dq_ggml" ]; then
        model_name_or_path="facebook/opt-125m"
        approach="weight_only"
@@ -64,7 +64,7 @@ function run_tuning {
        model_name_or_path="meta-llama/Llama-2-7b-hf"
        approach="weight_only"
        extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_enable_mse_search --gptq_pad_max_length 2048 --gptq_use_max_length"
-       extra_cmd=$extra_cmd" --double_quant_type BNB"
+       extra_cmd=$extra_cmd" --double_quant_type BNB_NF4"
    elif [ "${topology}" = "llama2_7b_gptq_int4_dq_ggml" ]; then
        model_name_or_path="meta-llama/Llama-2-7b-hf"
        approach="weight_only"
@@ -78,7 +78,7 @@ function run_tuning {
        model_name_or_path="EleutherAI/gpt-j-6b"
        approach="weight_only"
        extra_cmd=$extra_cmd" --woq_algo RTN --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_enable_mse_search"
-       extra_cmd=$extra_cmd" --double_quant_type BNB"
+       extra_cmd=$extra_cmd" --double_quant_type BNB_NF4"
    elif [ "${topology}" = "gpt_j_woq_rtn_int4_dq_ggml" ]; then
        model_name_or_path="EleutherAI/gpt-j-6b"
        approach="weight_only"
@@ -92,7 +92,7 @@ function run_tuning {
        model_name_or_path="EleutherAI/gpt-j-6b"
        approach="weight_only"
        extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_enable_mse_search --gptq_pad_max_length 2048 --gptq_use_max_length"
-       extra_cmd=$extra_cmd" --double_quant_type BNB"
+       extra_cmd=$extra_cmd" --double_quant_type BNB_NF4"
    elif [ "${topology}" = "gpt_j_woq_gptq_int4_dq_ggml" ]; then
        model_name_or_path="EleutherAI/gpt-j-6b"
        approach="weight_only"
diff --git a/neural_compressor/torch/algorithms/layer_wise/__init__.py b/neural_compressor/torch/algorithms/layer_wise/__init__.py
new file mode 100644
index 00000000000..453a608c4a9
--- /dev/null
+++ b/neural_compressor/torch/algorithms/layer_wise/__init__.py
@@ -0,0 +1,18 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Torch layer-wise quantization module."""
+from .utils import *
diff --git a/neural_compressor/torch/algorithms/layer_wise/load.py b/neural_compressor/torch/algorithms/layer_wise/load.py
new file mode 100644
index 00000000000..09700044a8f
--- /dev/null
+++ b/neural_compressor/torch/algorithms/layer_wise/load.py
@@ -0,0 +1,272 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Load one specify tensor from a bin file.""" + +import io +import os +import warnings +from typing import IO, Any, BinaryIO, Callable, Dict, Optional, Union + +from packaging.version import Version +from torch.serialization import ( + StorageType, + _get_restore_location, + _is_torchscript_zip, + _is_zipfile, + _maybe_decode_ascii, + _open_file_like, + _open_zipfile_reader, +) + +from neural_compressor.adaptor.torch_utils.layer_wise_quant import modified_pickle as pickle + +from .utils import torch + +torch_version = torch.__version__.split("+")[0] +version = Version(torch_version) + +FILE_LIKE = Union[str, os.PathLike, BinaryIO, IO[bytes]] +MAP_LOCATION = Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]] + +if version.release < Version("1.13.0").release: + UntypedStorage = torch._UntypedStorage +else: + UntypedStorage = torch.UntypedStorage + + +def _load(zip_file, tensor_name, prefix, map_location, pickle_module, pickle_file="data.pkl", **pickle_load_args): + restore_location = _get_restore_location(map_location) + + loaded_storages = {} + + def load_tensor(dtype, numel, key, location): + name = f"data/{key}" + + if version.release < Version("1.13.0").release: + storage = zip_file.get_storage_from_record(name, numel, torch._UntypedStorage).storage()._untyped() + typed_storage = torch.storage._TypedStorage(wrap_storage=restore_location(storage, location), dtype=dtype) + loaded_storages[key] = typed_storage + elif version.release < Version("2.0.0").release: # pragma: no cover + storage = zip_file.get_storage_from_record(name, numel, UntypedStorage).storage().untyped() + typed_storage = torch.storage.TypedStorage(wrap_storage=restore_location(storage, location), dtype=dtype) + loaded_storages[key] = typed_storage + else: + storage = zip_file.get_storage_from_record(name, numel, UntypedStorage)._typed_storage()._untyped_storage + typed_storage = torch.storage.TypedStorage( + wrap_storage=restore_location(storage, location), dtype=dtype, _internal=True + ) + + if typed_storage._data_ptr() != 0: + loaded_storages[key] = typed_storage + + return typed_storage + + load_module_mapping: Dict[str, str] = {"torch.tensor": "torch._tensor"} + + # Need to subclass Unpickler instead of directly monkey-patching the find_class method + # because it's marked readonly in pickle. + # The type: ignore is because mypy can't statically determine the type of this class. 
+ class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined] + def find_class(self, mod_name, name): + if type(name) is str and "Storage" in name: + try: + return StorageType(name) + except KeyError: # pragma: no cover + pass + mod_name = load_module_mapping.get(mod_name, mod_name) + return super().find_class(mod_name, name) + + def persistent_load(self, saved_id): + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] + + assert ( + typename == "storage" + ), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" + storage_type, key, location, numel = data + + if storage_type is UntypedStorage: # pragma: no cover + dtype = torch.uint8 + else: + dtype = storage_type.dtype + + if key in loaded_storages: + typed_storage = loaded_storages[key] + else: + name_list = [self.tensor_name] + if prefix: + no_prefix_name = self.tensor_name.split(".") + if prefix in no_prefix_name: + no_prefix_name.remove(prefix) + no_prefix_name = ".".join(no_prefix_name) + name_list.append(no_prefix_name) + if self.tensor_name and self.metastack[-1][-2] not in name_list: + # typed_storage = None + # loaded_storages[key] = typed_storage + # nbytes = numel * torch._utils._element_size(dtype) + # typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location)) + typed_storage = None + else: + nbytes = numel * torch._utils._element_size(dtype) + typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location)) + + return typed_storage + + # Load the data (which may in turn use `persistent_load` to load tensors) + data_file = io.BytesIO(zip_file.get_record(pickle_file)) + + unpickler = UnpicklerWrapper(data_file, **pickle_load_args) + # unpickler.persistent_load = persistent_load + result = unpickler.load(tensor_name) + + torch._utils._validate_loaded_sparse_tensors() + return result + + +def load( + f: FILE_LIKE, + tensor_name: str = None, + prefix: str = None, + map_location: MAP_LOCATION = None, + pickle_module: Any = None, + *, + weights_only: bool = False, + **pickle_load_args: Any, +) -> Any: + # Reference: https://github.com/pytorch/pytorch/issues/54354 + # The first line of this docstring overrides the one Sphinx generates for the + # documentation. We need it so that Sphinx doesn't leak `pickle`s path from + # the build environment (e.g. `>> # xdoctest: +SKIP("undefined filepaths") + >>> torch.load('tensors.pt') + # Load all tensors onto the CPU + >>> torch.load('tensors.pt', map_location=torch.device('cpu')) + # Load all tensors onto the CPU, using a function + >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage) + # Load all tensors onto GPU 1 + >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1)) + # Map tensors from GPU 1 to GPU 0 + >>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'}) + # Load tensor from io.BytesIO object + >>> with open('tensor.pt', 'rb') as f: + ... 
buffer = io.BytesIO(f.read()) + >>> torch.load(buffer) + # Load a module with 'ascii' encoding for unpickling + >>> torch.load('module.pt', encoding='ascii') + """ + torch._C._log_api_usage_once("torch.load") + # Add ability to force safe only weight loads via environment variable + if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ["1", "y", "yes", "true"]: # pragma: no cover + weights_only = True + + if weights_only: # pragma: no cover + if pickle_module is not None: + raise RuntimeError("Can not safely load weights when explicit pickle_module is specified") + else: + if pickle_module is None: + pickle_module = pickle + + if "encoding" not in pickle_load_args.keys(): + pickle_load_args["encoding"] = "utf-8" + + with _open_file_like(f, "rb") as opened_file: + if _is_zipfile(opened_file): + # The zipfile reader is going to advance the current file position. + # If we want to actually tail call to torch.jit.load, we need to + # reset back to the original position. + orig_position = opened_file.tell() + with _open_zipfile_reader(opened_file) as opened_zipfile: + if _is_torchscript_zip(opened_zipfile): # pragma: no cover + warnings.warn( + "'torch.load' received a zip file that looks like a TorchScript archive" + " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to" + " silence this warning)", + UserWarning, + ) + opened_file.seek(orig_position) + return torch.jit.load(opened_file, map_location=map_location) + return _load(opened_zipfile, tensor_name, prefix, map_location, pickle_module, **pickle_load_args) diff --git a/neural_compressor/torch/algorithms/layer_wise/modified_pickle.py b/neural_compressor/torch/algorithms/layer_wise/modified_pickle.py new file mode 100644 index 00000000000..eac4ce343e0 --- /dev/null +++ b/neural_compressor/torch/algorithms/layer_wise/modified_pickle.py @@ -0,0 +1,1861 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Create portable serialized representations of Python objects. + +See module copyreg for a mechanism for registering custom picklers. +See module pickletools source for extensive comments. 
+ +Classes: + + Pickler + Unpickler + +Functions: + + dump(object, file) + dumps(object) -> string + load(file) -> object + loads(string) -> object + +Misc variables: + + __version__ + format_version + compatible_formats +""" + +import codecs +import io +import re +import sys +from copyreg import _extension_cache, _extension_registry, _inverted_registry, dispatch_table +from functools import partial +from itertools import islice +from struct import pack, unpack +from sys import maxsize +from types import FunctionType + +import _compat_pickle + +__all__ = ["PickleError", "PicklingError", "UnpicklingError", "Pickler", "Unpickler", "dump", "dumps", "load", "loads"] + +try: + from _pickle import PickleBuffer + + __all__.append("PickleBuffer") + _HAVE_PICKLE_BUFFER = True +except ImportError: + _HAVE_PICKLE_BUFFER = False + + +# Shortcut for use in isinstance testing +bytes_types = (bytes, bytearray) + +# These are purely informational; no code uses these. +format_version = "4.0" # File format version we write +compatible_formats = [ + "1.0", # Original protocol 0 + "1.1", # Protocol 0 with INST added + "1.2", # Original protocol 1 + "1.3", # Protocol 1 with BINFLOAT added + "2.0", # Protocol 2 + "3.0", # Protocol 3 + "4.0", # Protocol 4 + "5.0", # Protocol 5 +] # Old format versions we can read + +# This is the highest protocol number we know how to read. +HIGHEST_PROTOCOL = 5 + +# The protocol we write by default. May be less than HIGHEST_PROTOCOL. +# Only bump this if the oldest still supported version of Python already +# includes it. +DEFAULT_PROTOCOL = 4 + + +class PickleError(Exception): + """A common base class for the other pickling exceptions.""" + + pass + + +class PicklingError(PickleError): + """This exception is raised when an unpicklable object is passed to the + dump() method.""" + + pass + + +class UnpicklingError(PickleError): + """This exception is raised when there is a problem unpickling an object, + such as a security violation. + + Note that other exceptions may also be raised during unpickling, including + (but not necessarily limited to) AttributeError, EOFError, ImportError, + and IndexError. + """ + + pass + + +# An instance of _Stop is raised by Unpickler.load_stop() in response to +# the STOP opcode, passing the object that is the result of unpickling. +class _Stop(Exception): + def __init__(self, value): + self.value = value + + +# Jython has PyStringMap; it's a dict subclass with string keys +try: + from org.python.core import PyStringMap +except ImportError: + PyStringMap = None + +# Pickle opcodes. See pickletools.py for extensive docs. The listing +# here is in kind-of alphabetical order of 1-character pickle code. +# pickletools groups them by purpose. +# fmt: off +MARK = b'(' # push special markobject on stack +STOP = b'.' 
# every pickle ends with STOP +POP = b'0' # discard topmost stack item +POP_MARK = b'1' # discard stack top through topmost markobject +DUP = b'2' # duplicate top stack item +FLOAT = b'F' # push float object; decimal string argument +INT = b'I' # push integer or bool; decimal string argument +BININT = b'J' # push four-byte signed int +BININT1 = b'K' # push 1-byte unsigned int +LONG = b'L' # push long; decimal string argument +BININT2 = b'M' # push 2-byte unsigned int +NONE = b'N' # push None +PERSID = b'P' # push persistent object; id is taken from string arg +BINPERSID = b'Q' # " " " ; " " " " stack +REDUCE = b'R' # apply callable to argtuple, both on stack +STRING = b'S' # push string; NL-terminated string argument +BINSTRING = b'T' # push string; counted binary string argument +SHORT_BINSTRING= b'U' # " " ; " " " " < 256 bytes +UNICODE = b'V' # push Unicode string; raw-unicode-escaped'd argument +BINUNICODE = b'X' # " " " ; counted UTF-8 string argument +APPEND = b'a' # append stack top to list below it +BUILD = b'b' # call __setstate__ or __dict__.update() +GLOBAL = b'c' # push self.find_class(modname, name); 2 string args +DICT = b'd' # build a dict from stack items +EMPTY_DICT = b'}' # push empty dict +APPENDS = b'e' # extend list on stack by topmost stack slice +GET = b'g' # push item from memo on stack; index is string arg +BINGET = b'h' # " " " " " " ; " " 1-byte arg +INST = b'i' # build & push class instance +LONG_BINGET = b'j' # push item from memo on stack; index is 4-byte arg +LIST = b'l' # build list from topmost stack items +EMPTY_LIST = b']' # push empty list +OBJ = b'o' # build & push class instance +PUT = b'p' # store stack top in memo; index is string arg +BINPUT = b'q' # " " " " " ; " " 1-byte arg +LONG_BINPUT = b'r' # " " " " " ; " " 4-byte arg +SETITEM = b's' # add key+value pair to dict +TUPLE = b't' # build tuple from topmost stack items +EMPTY_TUPLE = b')' # push empty tuple +SETITEMS = b'u' # modify dict by adding topmost key+value pairs +BINFLOAT = b'G' # push float; arg is 8-byte float encoding + +TRUE = b'I01\n' # not an opcode; see INT docs in pickletools.py +FALSE = b'I00\n' # not an opcode; see INT docs in pickletools.py + +# Protocol 2 + +PROTO = b'\x80' # identify pickle protocol +NEWOBJ = b'\x81' # build object by applying cls.__new__ to argtuple +EXT1 = b'\x82' # push object from extension registry; 1-byte index +EXT2 = b'\x83' # ditto, but 2-byte index +EXT4 = b'\x84' # ditto, but 4-byte index +TUPLE1 = b'\x85' # build 1-tuple from stack top +TUPLE2 = b'\x86' # build 2-tuple from two topmost stack items +TUPLE3 = b'\x87' # build 3-tuple from three topmost stack items +NEWTRUE = b'\x88' # push True +NEWFALSE = b'\x89' # push False +LONG1 = b'\x8a' # push long from < 256 bytes +LONG4 = b'\x8b' # push really big long + +_tuplesize2code = [EMPTY_TUPLE, TUPLE1, TUPLE2, TUPLE3] + +# Protocol 3 (Python 3.x) + +BINBYTES = b'B' # push bytes; counted binary string argument +SHORT_BINBYTES = b'C' # " " ; " " " " < 256 bytes + +# Protocol 4 + +SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes +BINUNICODE8 = b'\x8d' # push very long string +BINBYTES8 = b'\x8e' # push very long bytes string +EMPTY_SET = b'\x8f' # push empty set on the stack +ADDITEMS = b'\x90' # modify set by adding topmost stack items +FROZENSET = b'\x91' # build frozenset from topmost stack items +NEWOBJ_EX = b'\x92' # like NEWOBJ but work with keyword only arguments +STACK_GLOBAL = b'\x93' # same as GLOBAL but using names on the stacks +MEMOIZE = b'\x94' # store top of the 
stack in memo +FRAME = b'\x95' # indicate the beginning of a new frame + +# Protocol 5 + +BYTEARRAY8 = b'\x96' # push bytearray +NEXT_BUFFER = b'\x97' # push next out-of-band buffer +READONLY_BUFFER = b'\x98' # make top of stack readonly +# fmt: on +__all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$", x)]) + + +class _Framer: # pragma: no cover + _FRAME_SIZE_MIN = 4 + _FRAME_SIZE_TARGET = 64 * 1024 + + def __init__(self, file_write): + self.file_write = file_write + self.current_frame = None + + def start_framing(self): + self.current_frame = io.BytesIO() + + def end_framing(self): + if self.current_frame and self.current_frame.tell() > 0: + self.commit_frame(force=True) + self.current_frame = None + + def commit_frame(self, force=False): + if self.current_frame: + f = self.current_frame + if f.tell() >= self._FRAME_SIZE_TARGET or force: + data = f.getbuffer() + write = self.file_write + if len(data) >= self._FRAME_SIZE_MIN: + # Issue a single call to the write method of the underlying + # file object for the frame opcode with the size of the + # frame. The concatenation is expected to be less expensive + # than issuing an additional call to write. + write(FRAME + pack("": + raise AttributeError("Can't get local attribute {!r} on {!r}".format(name, obj)) + try: + parent = obj + obj = getattr(obj, subpath) + except AttributeError: + raise AttributeError("Can't get attribute {!r} on {!r}".format(name, obj)) from None + return obj, parent + + +def whichmodule(obj, name): # pragma: no cover + """Find the module an object belong to.""" + module_name = getattr(obj, "__module__", None) + if module_name is not None: + return module_name + # Protect the iteration by using a list copy of sys.modules against dynamic + # modules that trigger imports of other modules upon calls to getattr. + for module_name, module in sys.modules.copy().items(): + if module_name == "__main__" or module_name == "__mp_main__" or module is None: # bpo-42406 + continue + try: + if _getattribute(module, name)[0] is obj: + return module_name + except AttributeError: + pass + return "__main__" + + +def encode_long(x): # pragma: no cover + r"""Encode a long to a two's complement little-endian binary string. + Note that 0 is a special case, returning an empty string, to save a + byte in the LONG1 pickling context. + + >>> encode_long(0) + b'' + >>> encode_long(255) + b'\xff\x00' + >>> encode_long(32767) + b'\xff\x7f' + >>> encode_long(-256) + b'\x00\xff' + >>> encode_long(-32768) + b'\x00\x80' + >>> encode_long(-128) + b'\x80' + >>> encode_long(127) + b'\x7f' + >>> + """ + if x == 0: + return b"" + nbytes = (x.bit_length() >> 3) + 1 + result = x.to_bytes(nbytes, byteorder="little", signed=True) + if x < 0 and nbytes > 1: + if result[-1] == 0xFF and (result[-2] & 0x80) != 0: + result = result[:-1] + return result + + +def decode_long(data): # pragma: no cover + r"""Decode a long from a two's complement little-endian binary string. + + >>> decode_long(b'') + 0 + >>> decode_long(b"\xff\x00") + 255 + >>> decode_long(b"\xff\x7f") + 32767 + >>> decode_long(b"\x00\xff") + -256 + >>> decode_long(b"\x00\x80") + -32768 + >>> decode_long(b"\x80") + -128 + >>> decode_long(b"\x7f") + 127 + """ + return int.from_bytes(data, byteorder="little", signed=True) + + +# Pickling machinery + + +class _Pickler: # pragma: no cover + def __init__(self, file, protocol=None, *, fix_imports=True, buffer_callback=None): + """This takes a binary file for writing a pickle data stream. 
+ + The optional *protocol* argument tells the pickler to use the + given protocol; supported protocols are 0, 1, 2, 3, 4 and 5. + The default protocol is 4. It was introduced in Python 3.4, and + is incompatible with previous versions. + + Specifying a negative protocol version selects the highest + protocol version supported. The higher the protocol used, the + more recent the version of Python needed to read the pickle + produced. + + The *file* argument must have a write() method that accepts a + single bytes argument. It can thus be a file object opened for + binary writing, an io.BytesIO instance, or any other custom + object that meets this interface. + + If *fix_imports* is True and *protocol* is less than 3, pickle + will try to map the new Python 3 names to the old module names + used in Python 2, so that the pickle data stream is readable + with Python 2. + + If *buffer_callback* is None (the default), buffer views are + serialized into *file* as part of the pickle stream. + + If *buffer_callback* is not None, then it can be called any number + of times with a buffer view. If the callback returns a false value + (such as None), the given buffer is out-of-band; otherwise the + buffer is serialized in-band, i.e. inside the pickle stream. + + It is an error if *buffer_callback* is not None and *protocol* + is None or smaller than 5. + """ + if protocol is None: + protocol = DEFAULT_PROTOCOL + if protocol < 0: + protocol = HIGHEST_PROTOCOL + elif not 0 <= protocol <= HIGHEST_PROTOCOL: + raise ValueError("pickle protocol must be <= %d" % HIGHEST_PROTOCOL) + if buffer_callback is not None and protocol < 5: + raise ValueError("buffer_callback needs protocol >= 5") + self._buffer_callback = buffer_callback + try: + self._file_write = file.write + except AttributeError: + raise TypeError("file must have a 'write' attribute") + self.framer = _Framer(self._file_write) + self.write = self.framer.write + self._write_large_bytes = self.framer.write_large_bytes + self.memo = {} + self.proto = int(protocol) + self.bin = protocol >= 1 + self.fast = 0 + self.fix_imports = fix_imports and protocol < 3 + + def clear_memo(self): + """Clears the pickler's "memo". + + The memo is the data structure that remembers which objects the + pickler has already seen, so that shared or recursive objects + are pickled by reference and not by value. This method is + useful when re-using picklers. + """ + self.memo.clear() + + def dump(self, obj): + """Write a pickled representation of obj to the open file.""" + # Check whether Pickler was initialized correctly. This is + # only needed to mimic the behavior of _pickle.Pickler.dump(). + if not hasattr(self, "_file_write"): + raise PicklingError("Pickler.__init__() was not called by " "%s.__init__()" % (self.__class__.__name__,)) + if self.proto >= 2: + self.write(PROTO + pack("= 4: + self.framer.start_framing() + self.save(obj) + self.write(STOP) + self.framer.end_framing() + + def memoize(self, obj): + """Store an object in the memo.""" + + # The Pickler memo is a dictionary mapping object ids to 2-tuples + # that contain the Unpickler memo key and the object being memoized. + # The memo key is written to the pickle and will become + # the key in the Unpickler's memo. The object is stored in the + # Pickler memo so that transient objects are kept alive during + # pickling. + + # The use of the Unpickler memo length as the memo key is just a + # convention. The only requirement is that the memo values be unique. 
+ # But there appears no advantage to any other scheme, and this + # scheme allows the Unpickler memo to be implemented as a plain (but + # growable) array, indexed by memo key. + if self.fast: + return + assert id(obj) not in self.memo + idx = len(self.memo) + self.write(self.put(idx)) + self.memo[id(obj)] = idx, obj + + # Return a PUT (BINPUT, LONG_BINPUT) opcode string, with argument i. + def put(self, idx): + if self.proto >= 4: + return MEMOIZE + elif self.bin: + if idx < 256: + return BINPUT + pack("= 2 and func_name == "__newobj_ex__": + cls, args, kwargs = args + if not hasattr(cls, "__new__"): + raise PicklingError("args[0] from {} args has no __new__".format(func_name)) + if obj is not None and cls is not obj.__class__: + raise PicklingError("args[0] from {} args has the wrong class".format(func_name)) + if self.proto >= 4: + save(cls) + save(args) + save(kwargs) + write(NEWOBJ_EX) + else: + func = partial(cls.__new__, cls, *args, **kwargs) + save(func) + save(()) + write(REDUCE) + elif self.proto >= 2 and func_name == "__newobj__": + # A __reduce__ implementation can direct protocol 2 or newer to + # use the more efficient NEWOBJ opcode, while still + # allowing protocol 0 and 1 to work normally. For this to + # work, the function returned by __reduce__ should be + # called __newobj__, and its first argument should be a + # class. The implementation for __newobj__ + # should be as follows, although pickle has no way to + # verify this: + # + # def __newobj__(cls, *args): + # return cls.__new__(cls, *args) + # + # Protocols 0 and 1 will pickle a reference to __newobj__, + # while protocol 2 (and above) will pickle a reference to + # cls, the remaining args tuple, and the NEWOBJ code, + # which calls cls.__new__(cls, *args) at unpickling time + # (see load_newobj below). If __reduce__ returns a + # three-tuple, the state from the third tuple item will be + # pickled regardless of the protocol, calling __setstate__ + # at unpickling time (see load_build below). + # + # Note that no standard __newobj__ implementation exists; + # you have to provide your own. This is to enforce + # compatibility with Python 2.2 (pickles written using + # protocol 0 or 1 in Python 2.3 should be unpicklable by + # Python 2.2). + cls = args[0] + if not hasattr(cls, "__new__"): + raise PicklingError("args[0] from __newobj__ args has no __new__") + if obj is not None and cls is not obj.__class__: + raise PicklingError("args[0] from __newobj__ args has the wrong class") + args = args[1:] + save(cls) + save(args) + write(NEWOBJ) + else: + save(func) + save(args) + write(REDUCE) + + if obj is not None: + # If the object is already in the memo, this means it is + # recursive. In this case, throw away everything we put on the + # stack, and fetch the object back from the memo. + if id(obj) in self.memo: + write(POP + self.get(self.memo[id(obj)][0])) + else: + self.memoize(obj) + + # More new special cases (that work with older protocols as + # well): when __reduce__ returns a tuple with 4 or 5 items, + # the 4th and 5th item should be iterators that provide list + # items and dict items (as (key, value) tuples), or None. + + if listitems is not None: + self._batch_appends(listitems) + + if dictitems is not None: + self._batch_setitems(dictitems) + + if state is not None: + if state_setter is None: + save(state) + write(BUILD) + else: + # If a state_setter is specified, call it instead of load_build + # to update obj's with its previous state. 
+ # First, push state_setter and its tuple of expected arguments + # (obj, state) onto the stack. + save(state_setter) + save(obj) # simple BINGET opcode as obj is already memoized. + save(state) + write(TUPLE2) + # Trigger a state_setter(obj, state) function call. + write(REDUCE) + # The purpose of state_setter is to carry-out an + # inplace modification of obj. We do not care about what the + # method might return, so its output is eventually removed from + # the stack. + write(POP) + + # Methods below this point are dispatched through the dispatch table + + dispatch = {} + + def save_none(self, obj): + self.write(NONE) + + dispatch[type(None)] = save_none + + def save_bool(self, obj): + if self.proto >= 2: + self.write(NEWTRUE if obj else NEWFALSE) + else: + self.write(TRUE if obj else FALSE) + + dispatch[bool] = save_bool + + def save_long(self, obj): + if self.bin: + # If the int is small enough to fit in a signed 4-byte 2's-comp + # format, we can store it more efficiently than the general + # case. + # First one- and two-byte unsigned ints: + if obj >= 0: + if obj <= 0xFF: + self.write(BININT1 + pack("= 2: + encoded = encode_long(obj) + n = len(encoded) + if n < 256: + self.write(LONG1 + pack("d", obj)) + else: + self.write(FLOAT + repr(obj).encode("ascii") + b"\n") + + dispatch[float] = save_float + + def save_bytes(self, obj): + if self.proto < 3: + if not obj: # bytes object is empty + self.save_reduce(bytes, (), obj=obj) + else: + self.save_reduce(codecs.encode, (str(obj, "latin1"), "latin1"), obj=obj) + return + n = len(obj) + if n <= 0xFF: + self.write(SHORT_BINBYTES + pack(" 0xFFFFFFFF and self.proto >= 4: + self._write_large_bytes(BINBYTES8 + pack("= self.framer._FRAME_SIZE_TARGET: + self._write_large_bytes(BINBYTES + pack("= self.framer._FRAME_SIZE_TARGET: + self._write_large_bytes(BYTEARRAY8 + pack("= 5") + with obj.raw() as m: + if not m.contiguous: + raise PicklingError("PickleBuffer can not be pickled when " "pointing to a non-contiguous buffer") + in_band = True + if self._buffer_callback is not None: + in_band = bool(self._buffer_callback(obj)) + if in_band: + # Write data in-band + # XXX The C implementation avoids a copy here + if m.readonly: + self.save_bytes(m.tobytes()) + else: + self.save_bytearray(m.tobytes()) + else: + # Write data out-of-band + self.write(NEXT_BUFFER) + if m.readonly: + self.write(READONLY_BUFFER) + + dispatch[PickleBuffer] = save_picklebuffer + + def save_str(self, obj): + if self.bin: + encoded = obj.encode("utf-8", "surrogatepass") + n = len(encoded) + if n <= 0xFF and self.proto >= 4: + self.write(SHORT_BINUNICODE + pack(" 0xFFFFFFFF and self.proto >= 4: + self._write_large_bytes(BINUNICODE8 + pack("= self.framer._FRAME_SIZE_TARGET: + self._write_large_bytes(BINUNICODE + pack("= 2: + for element in obj: + save(element) + # Subtle. Same as in the big comment below. + if id(obj) in memo: + get = self.get(memo[id(obj)][0]) + self.write(POP * n + get) + else: + self.write(_tuplesize2code[n]) + self.memoize(obj) + return + + # proto 0 or proto 1 and tuple isn't empty, or proto > 1 and tuple + # has more than 3 elements. + write = self.write + write(MARK) + for element in obj: + save(element) + + if id(obj) in memo: + # Subtle. d was not in memo when we entered save_tuple(), so + # the process of saving the tuple's elements must have saved + # the tuple itself: the tuple is recursive. The proper action + # now is to throw away everything we put on the stack, and + # simply GET the tuple (it's already constructed). 
This check + # could have been done in the "for element" loop instead, but + # recursive tuples are a rare thing. + get = self.get(memo[id(obj)][0]) + if self.bin: + write(POP_MARK + get) + else: # proto 0 -- POP_MARK not available + write(POP * (n + 1) + get) + return + + # No recursion. + write(TUPLE) + self.memoize(obj) + + dispatch[tuple] = save_tuple + + def save_list(self, obj): + if self.bin: + self.write(EMPTY_LIST) + else: # proto 0 -- can't use EMPTY_LIST + self.write(MARK + LIST) + + self.memoize(obj) + self._batch_appends(obj) + + dispatch[list] = save_list + + _BATCHSIZE = 1000 + + def _batch_appends(self, items): + # Helper to batch up APPENDS sequences + save = self.save + write = self.write + + if not self.bin: + for x in items: + save(x) + write(APPEND) + return + + it = iter(items) + while True: + tmp = list(islice(it, self._BATCHSIZE)) + n = len(tmp) + if n > 1: + write(MARK) + for x in tmp: + save(x) + write(APPENDS) + elif n: + save(tmp[0]) + write(APPEND) + # else tmp is empty, and we're done + if n < self._BATCHSIZE: + return + + def save_dict(self, obj): + if self.bin: + self.write(EMPTY_DICT) + else: # proto 0 -- can't use EMPTY_DICT + self.write(MARK + DICT) + + self.memoize(obj) + self._batch_setitems(obj.items()) + + dispatch[dict] = save_dict + if PyStringMap is not None: + dispatch[PyStringMap] = save_dict + + def _batch_setitems(self, items): + # Helper to batch up SETITEMS sequences; proto >= 1 only + save = self.save + write = self.write + + if not self.bin: + for k, v in items: + save(k) + save(v) + write(SETITEM) + return + + it = iter(items) + while True: + tmp = list(islice(it, self._BATCHSIZE)) + n = len(tmp) + if n > 1: + write(MARK) + for k, v in tmp: + save(k) + save(v) + write(SETITEMS) + elif n: + k, v = tmp[0] + save(k) + save(v) + write(SETITEM) + # else tmp is empty, and we're done + if n < self._BATCHSIZE: + return + + def save_set(self, obj): + save = self.save + write = self.write + + if self.proto < 4: + self.save_reduce(set, (list(obj),), obj=obj) + return + + write(EMPTY_SET) + self.memoize(obj) + + it = iter(obj) + while True: + batch = list(islice(it, self._BATCHSIZE)) + n = len(batch) + if n > 0: + write(MARK) + for item in batch: + save(item) + write(ADDITEMS) + if n < self._BATCHSIZE: + return + + dispatch[set] = save_set + + def save_frozenset(self, obj): + save = self.save + write = self.write + + if self.proto < 4: + self.save_reduce(frozenset, (list(obj),), obj=obj) + return + + write(MARK) + for item in obj: + save(item) + + if id(obj) in self.memo: + # If the object is already in the memo, this means it is + # recursive. In this case, throw away everything we put on the + # stack, and fetch the object back from the memo. 
+ write(POP_MARK + self.get(self.memo[id(obj)][0])) + return + + write(FROZENSET) + self.memoize(obj) + + dispatch[frozenset] = save_frozenset + + def save_global(self, obj, name=None): + write = self.write + memo = self.memo + + if name is None: + name = getattr(obj, "__qualname__", None) + if name is None: + name = obj.__name__ + + module_name = whichmodule(obj, name) + try: + __import__(module_name, level=0) + module = sys.modules[module_name] + obj2, parent = _getattribute(module, name) + except (ImportError, KeyError, AttributeError): + raise PicklingError("Can't pickle %r: it's not found as %s.%s" % (obj, module_name, name)) from None + else: + if obj2 is not obj: + raise PicklingError("Can't pickle %r: it's not the same object as %s.%s" % (obj, module_name, name)) + + if self.proto >= 2: + code = _extension_registry.get((module_name, name)) + if code: + assert code > 0 + if code <= 0xFF: + write(EXT1 + pack("= 3. + if self.proto >= 4: + self.save(module_name) + self.save(name) + write(STACK_GLOBAL) + elif parent is not module: + self.save_reduce(getattr, (parent, lastname)) + elif self.proto >= 3: + write(GLOBAL + bytes(module_name, "utf-8") + b"\n" + bytes(name, "utf-8") + b"\n") + else: + if self.fix_imports: + r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING + r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING + if (module_name, name) in r_name_mapping: + module_name, name = r_name_mapping[(module_name, name)] + elif module_name in r_import_mapping: + module_name = r_import_mapping[module_name] + try: + write(GLOBAL + bytes(module_name, "ascii") + b"\n" + bytes(name, "ascii") + b"\n") + except UnicodeEncodeError: + raise PicklingError( + "can't pickle global identifier '%s.%s' using " "pickle protocol %i" % (module, name, self.proto) + ) from None + + self.memoize(obj) + + def save_type(self, obj): + if obj is type(None): + return self.save_reduce(type, (None,), obj=obj) + elif obj is type(NotImplemented): + return self.save_reduce(type, (NotImplemented,), obj=obj) + elif obj is type(...): + return self.save_reduce(type, (...,), obj=obj) + return self.save_global(obj) + + dispatch[FunctionType] = save_global + dispatch[type] = save_type + + +# Unpickling machinery + + +class _Unpickler: # pragma: no cover + def __init__(self, file, *, fix_imports=True, encoding="ASCII", errors="strict", buffers=None): + """This takes a binary file for reading a pickle data stream. + + The protocol version of the pickle is detected automatically, so + no proto argument is needed. + + The argument *file* must have two methods, a read() method that + takes an integer argument, and a readline() method that requires + no arguments. Both methods should return bytes. Thus *file* + can be a binary file object opened for reading, an io.BytesIO + object, or any other custom object that meets this interface. + + The file-like object must have two methods, a read() method + that takes an integer argument, and a readline() method that + requires no arguments. Both methods should return bytes. + Thus file-like object can be a binary file object opened for + reading, a BytesIO object, or any other custom object that + meets this interface. + + If *buffers* is not None, it should be an iterable of buffer-enabled + objects that is consumed each time the pickle stream references + an out-of-band buffer view. Such buffers have been given in order + to the *buffer_callback* of a Pickler object. 
+ + If *buffers* is None (the default), then the buffers are taken + from the pickle stream, assuming they are serialized there. + It is an error for *buffers* to be None if the pickle stream + was produced with a non-None *buffer_callback*. + + Other optional arguments are *fix_imports*, *encoding* and + *errors*, which are used to control compatibility support for + pickle stream generated by Python 2. If *fix_imports* is True, + pickle will try to map the old Python 2 names to the new names + used in Python 3. The *encoding* and *errors* tell pickle how + to decode 8-bit string instances pickled by Python 2; these + default to 'ASCII' and 'strict', respectively. *encoding* can be + 'bytes' to read these 8-bit string instances as bytes objects. + """ + self._buffers = iter(buffers) if buffers is not None else None + self._file_readline = file.readline + self._file_read = file.read + self.memo = {} + self.encoding = encoding + self.errors = errors + self.proto = 0 + self.fix_imports = fix_imports + + def load(self, tensor_name=None): + """Read a pickled object representation from the open file. + + Return the reconstituted object hierarchy specified in the file. + """ + # Check whether Unpickler was initialized correctly. This is + # only needed to mimic the behavior of _pickle.Unpickler.dump(). + + if not hasattr(self, "_file_read"): + raise UnpicklingError( + "Unpickler.__init__() was not called by " "%s.__init__()" % (self.__class__.__name__,) + ) + self.tensor_name = tensor_name + self._unframer = _Unframer(self._file_read, self._file_readline) + self.read = self._unframer.read + self.readinto = self._unframer.readinto + self.readline = self._unframer.readline + self.metastack = [] + self.stack = [] + self.append = self.stack.append + self.proto = 0 + read = self.read + dispatch = self.dispatch + try: + while True: + key = read(1) + if not key: + raise EOFError + assert isinstance(key, bytes_types) + dispatch[key[0]](self) + except _Stop as stopinst: + return stopinst.value + + # Return a list of items pushed in the stack after last MARK instruction. 
+ def pop_mark(self): + items = self.stack + self.stack = self.metastack.pop() + self.append = self.stack.append + return items + + def persistent_load(self, pid): + raise UnpicklingError("unsupported persistent id encountered") + + dispatch = {} + + def load_proto(self): + proto = self.read(1)[0] + if not 0 <= proto <= HIGHEST_PROTOCOL: + raise ValueError("unsupported pickle protocol: %d" % proto) + self.proto = proto + + dispatch[PROTO[0]] = load_proto + + def load_frame(self): + (frame_size,) = unpack(" sys.maxsize: + raise ValueError("frame size > sys.maxsize: %d" % frame_size) + self._unframer.load_frame(frame_size) + + dispatch[FRAME[0]] = load_frame + + def load_persid(self): + try: + pid = self.readline()[:-1].decode("ascii") + except UnicodeDecodeError: + raise UnpicklingError("persistent IDs in protocol 0 must be ASCII strings") + self.append(self.persistent_load(pid)) + + dispatch[PERSID[0]] = load_persid + + def load_binpersid(self): + pid = self.stack.pop() + self.append(self.persistent_load(pid)) + + dispatch[BINPERSID[0]] = load_binpersid + + def load_none(self): + self.append(None) + + dispatch[NONE[0]] = load_none + + def load_false(self): + self.append(False) + + dispatch[NEWFALSE[0]] = load_false + + def load_true(self): + self.append(True) + + dispatch[NEWTRUE[0]] = load_true + + def load_int(self): + data = self.readline() + if data == FALSE[1:]: + val = False + elif data == TRUE[1:]: + val = True + else: + val = int(data, 0) + self.append(val) + + dispatch[INT[0]] = load_int + + def load_binint(self): + self.append(unpack("d", self.read(8))[0]) + + dispatch[BINFLOAT[0]] = load_binfloat + + def _decode_string(self, value): + # Used to allow strings from Python 2 to be decoded either as + # bytes or Unicode strings. This should be used only with the + # STRING, BINSTRING and SHORT_BINSTRING opcodes. 
+ if self.encoding == "bytes": + return value + else: + return value.decode(self.encoding, self.errors) + + def load_string(self): + data = self.readline()[:-1] + # Strip outermost quotes + if len(data) >= 2 and data[0] == data[-1] and data[0] in b"\"'": + data = data[1:-1] + else: + raise UnpicklingError("the STRING opcode argument must be quoted") + self.append(self._decode_string(codecs.escape_decode(data)[0])) + + dispatch[STRING[0]] = load_string + + def load_binstring(self): + # Deprecated BINSTRING uses signed 32-bit length + (len,) = unpack(" maxsize: + raise UnpicklingError("BINBYTES exceeds system's maximum size " "of %d bytes" % maxsize) + self.append(self.read(len)) + + dispatch[BINBYTES[0]] = load_binbytes + + def load_unicode(self): + self.append(str(self.readline()[:-1], "raw-unicode-escape")) + + dispatch[UNICODE[0]] = load_unicode + + def load_binunicode(self): + (len,) = unpack(" maxsize: + raise UnpicklingError("BINUNICODE exceeds system's maximum size " "of %d bytes" % maxsize) + self.append(str(self.read(len), "utf-8", "surrogatepass")) + + dispatch[BINUNICODE[0]] = load_binunicode + + def load_binunicode8(self): + (len,) = unpack(" maxsize: + raise UnpicklingError("BINUNICODE8 exceeds system's maximum size " "of %d bytes" % maxsize) + self.append(str(self.read(len), "utf-8", "surrogatepass")) + + dispatch[BINUNICODE8[0]] = load_binunicode8 + + def load_binbytes8(self): + (len,) = unpack(" maxsize: + raise UnpicklingError("BINBYTES8 exceeds system's maximum size " "of %d bytes" % maxsize) + self.append(self.read(len)) + + dispatch[BINBYTES8[0]] = load_binbytes8 + + def load_bytearray8(self): + (len,) = unpack(" maxsize: + raise UnpicklingError("BYTEARRAY8 exceeds system's maximum size " "of %d bytes" % maxsize) + b = bytearray(len) + self.readinto(b) + self.append(b) + + dispatch[BYTEARRAY8[0]] = load_bytearray8 + + def load_next_buffer(self): + if self._buffers is None: + raise UnpicklingError("pickle stream refers to out-of-band data " "but no *buffers* argument was given") + try: + buf = next(self._buffers) + except StopIteration: + raise UnpicklingError("not enough out-of-band buffers") + self.append(buf) + + dispatch[NEXT_BUFFER[0]] = load_next_buffer + + def load_readonly_buffer(self): + buf = self.stack[-1] + with memoryview(buf) as m: + if not m.readonly: + self.stack[-1] = m.toreadonly() + + dispatch[READONLY_BUFFER[0]] = load_readonly_buffer + + def load_short_binstring(self): + len = self.read(1)[0] + data = self.read(len) + self.append(self._decode_string(data)) + + dispatch[SHORT_BINSTRING[0]] = load_short_binstring + + def load_short_binbytes(self): + len = self.read(1)[0] + self.append(self.read(len)) + + dispatch[SHORT_BINBYTES[0]] = load_short_binbytes + + def load_short_binunicode(self): + len = self.read(1)[0] + self.append(str(self.read(len), "utf-8", "surrogatepass")) + + dispatch[SHORT_BINUNICODE[0]] = load_short_binunicode + + def load_tuple(self): + items = self.pop_mark() + self.append(tuple(items)) + + dispatch[TUPLE[0]] = load_tuple + + def load_empty_tuple(self): + self.append(()) + + dispatch[EMPTY_TUPLE[0]] = load_empty_tuple + + def load_tuple1(self): + self.stack[-1] = (self.stack[-1],) + + dispatch[TUPLE1[0]] = load_tuple1 + + def load_tuple2(self): + self.stack[-2:] = [(self.stack[-2], self.stack[-1])] + + dispatch[TUPLE2[0]] = load_tuple2 + + def load_tuple3(self): + self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])] + + dispatch[TUPLE3[0]] = load_tuple3 + + def load_empty_list(self): + self.append([]) + + 
dispatch[EMPTY_LIST[0]] = load_empty_list + + def load_empty_dictionary(self): + self.append({}) + + dispatch[EMPTY_DICT[0]] = load_empty_dictionary + + def load_empty_set(self): + self.append(set()) + + dispatch[EMPTY_SET[0]] = load_empty_set + + def load_frozenset(self): + items = self.pop_mark() + self.append(frozenset(items)) + + dispatch[FROZENSET[0]] = load_frozenset + + def load_list(self): + items = self.pop_mark() + self.append(items) + + dispatch[LIST[0]] = load_list + + def load_dict(self): + items = self.pop_mark() + d = {items[i]: items[i + 1] for i in range(0, len(items), 2)} + self.append(d) + + dispatch[DICT[0]] = load_dict + + # INST and OBJ differ only in how they get a class object. It's not + # only sensible to do the rest in a common routine, the two routines + # previously diverged and grew different bugs. + # klass is the class to instantiate, and k points to the topmost mark + # object, following which are the arguments for klass.__init__. + def _instantiate(self, klass, args): + if args or not isinstance(klass, type) or hasattr(klass, "__getinitargs__"): + try: + value = klass(*args) + except TypeError as err: + raise TypeError("in constructor for %s: %s" % (klass.__name__, str(err)), sys.exc_info()[2]) + else: + value = klass.__new__(klass) + self.append(value) + + def load_inst(self): + module = self.readline()[:-1].decode("ascii") + name = self.readline()[:-1].decode("ascii") + klass = self.find_class(module, name) + self._instantiate(klass, self.pop_mark()) + + dispatch[INST[0]] = load_inst + + def load_obj(self): + # Stack is ... markobject classobject arg1 arg2 ... + args = self.pop_mark() + cls = args.pop(0) + self._instantiate(cls, args) + + dispatch[OBJ[0]] = load_obj + + def load_newobj(self): + args = self.stack.pop() + cls = self.stack.pop() + obj = cls.__new__(cls, *args) + self.append(obj) + + dispatch[NEWOBJ[0]] = load_newobj + + def load_newobj_ex(self): + kwargs = self.stack.pop() + args = self.stack.pop() + cls = self.stack.pop() + obj = cls.__new__(cls, *args, **kwargs) + self.append(obj) + + dispatch[NEWOBJ_EX[0]] = load_newobj_ex + + def load_global(self): + module = self.readline()[:-1].decode("utf-8") + name = self.readline()[:-1].decode("utf-8") + klass = self.find_class(module, name) + self.append(klass) + + dispatch[GLOBAL[0]] = load_global + + def load_stack_global(self): + name = self.stack.pop() + module = self.stack.pop() + if type(name) is not str or type(module) is not str: + raise UnpicklingError("STACK_GLOBAL requires str") + self.append(self.find_class(module, name)) + + dispatch[STACK_GLOBAL[0]] = load_stack_global + + def load_ext1(self): + code = self.read(1)[0] + self.get_extension(code) + + dispatch[EXT1[0]] = load_ext1 + + def load_ext2(self): + (code,) = unpack("= 4: + return _getattribute(sys.modules[module], name)[0] + else: + return getattr(sys.modules[module], name) + + def load_reduce(self): + stack = self.stack + args = stack.pop() + func = stack[-1] + if len(args) > 0 and args[0] is None: + stack[-1] = None + else: + stack[-1] = func(*args) + # stack[-1] = func(*args) + + dispatch[REDUCE[0]] = load_reduce + + def load_pop(self): + if self.stack: + del self.stack[-1] + else: + self.pop_mark() + + dispatch[POP[0]] = load_pop + + def load_pop_mark(self): + self.pop_mark() + + dispatch[POP_MARK[0]] = load_pop_mark + + def load_dup(self): + self.append(self.stack[-1]) + + dispatch[DUP[0]] = load_dup + + def load_get(self): + i = int(self.readline()[:-1]) + self.append(self.memo[i]) + + dispatch[GET[0]] = load_get + + def 
load_binget(self): + i = self.read(1)[0] + self.append(self.memo[i]) + + dispatch[BINGET[0]] = load_binget + + def load_long_binget(self): + (i,) = unpack(" maxsize: + raise ValueError("negative LONG_BINPUT argument") + self.memo[i] = self.stack[-1] + + dispatch[LONG_BINPUT[0]] = load_long_binput + + def load_memoize(self): + memo = self.memo + memo[len(memo)] = self.stack[-1] + + dispatch[MEMOIZE[0]] = load_memoize + + def load_append(self): + stack = self.stack + value = stack.pop() + list = stack[-1] + list.append(value) + + dispatch[APPEND[0]] = load_append + + def load_appends(self): + items = self.pop_mark() + list_obj = self.stack[-1] + try: + extend = list_obj.extend + except AttributeError: + pass + else: + extend(items) + return + # Even if the PEP 307 requires extend() and append() methods, + # fall back on append() if the object has no extend() method + # for backward compatibility. + append = list_obj.append + for item in items: + append(item) + + dispatch[APPENDS[0]] = load_appends + + def load_setitem(self): + stack = self.stack + value = stack.pop() + key = stack.pop() + dict = stack[-1] + dict[key] = value + + dispatch[SETITEM[0]] = load_setitem + + def load_setitems(self): + items = self.pop_mark() + dict = self.stack[-1] + for i in range(0, len(items), 2): + dict[items[i]] = items[i + 1] + + dispatch[SETITEMS[0]] = load_setitems + + def load_additems(self): + items = self.pop_mark() + set_obj = self.stack[-1] + if isinstance(set_obj, set): + set_obj.update(items) + else: + add = set_obj.add + for item in items: + add(item) + + dispatch[ADDITEMS[0]] = load_additems + + def load_build(self): + stack = self.stack + state = stack.pop() + inst = stack[-1] + setstate = getattr(inst, "__setstate__", None) + if setstate is not None: + setstate(state) + return + slotstate = None + if isinstance(state, tuple) and len(state) == 2: + state, slotstate = state + if state: + inst_dict = inst.__dict__ + intern = sys.intern + for k, v in state.items(): + if type(k) is str: + inst_dict[intern(k)] = v + else: + inst_dict[k] = v + if slotstate: + for k, v in slotstate.items(): + setattr(inst, k, v) + + dispatch[BUILD[0]] = load_build + + def load_mark(self): + self.metastack.append(self.stack) + self.stack = [] + self.append = self.stack.append + + dispatch[MARK[0]] = load_mark + + def load_stop(self): + value = self.stack.pop() + raise _Stop(value) + + dispatch[STOP[0]] = load_stop + + +# Shorthands + + +def _dump(obj, file, protocol=None, *, fix_imports=True, buffer_callback=None): # pragma: no cover + _Pickler(file, protocol, fix_imports=fix_imports, buffer_callback=buffer_callback).dump(obj) + + +def _dumps(obj, protocol=None, *, fix_imports=True, buffer_callback=None): # pragma: no cover + f = io.BytesIO() + _Pickler(f, protocol, fix_imports=fix_imports, buffer_callback=buffer_callback).dump(obj) + res = f.getvalue() + assert isinstance(res, bytes_types) + return res + + +def _load(file, *, fix_imports=True, encoding="ASCII", errors="strict", buffers=None): # pragma: no cover + return _Unpickler(file, fix_imports=fix_imports, buffers=buffers, encoding=encoding, errors=errors).load() + + +def _loads(s, *, fix_imports=True, encoding="ASCII", errors="strict", buffers=None): # pragma: no cover + if isinstance(s, str): + raise TypeError("Can't load pickle from unicode string") + file = io.BytesIO(s) + return _Unpickler(file, fix_imports=fix_imports, buffers=buffers, encoding=encoding, errors=errors).load() + + +# Use the faster _pickle if possible +Pickler, Unpickler = _Pickler, _Unpickler 
+dump, dumps, load, loads = _dump, _dumps, _load, _loads + + +# Doctest +def _test(): # pragma: no cover + import doctest + + return doctest.testmod() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="display contents of the pickle files") + parser.add_argument("pickle_file", type=argparse.FileType("br"), nargs="*", help="the pickle file") + parser.add_argument("-t", "--test", action="store_true", help="run self-test suite") + parser.add_argument("-v", action="store_true", help="run verbosely; only affects self-test run") + args = parser.parse_args() + if args.test: + _test() + else: + if not args.pickle_file: + parser.print_help() + else: + import pprint + + for f in args.pickle_file: + obj = load(f) + pprint.pprint(obj) diff --git a/neural_compressor/torch/algorithms/layer_wise/utils.py b/neural_compressor/torch/algorithms/layer_wise/utils.py new file mode 100644 index 00000000000..464a25cdee0 --- /dev/null +++ b/neural_compressor/torch/algorithms/layer_wise/utils.py @@ -0,0 +1,296 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utils for layer wise quantization.""" + +import gc +import json +import os + +import torch +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from transformers import AutoConfig, AutoModelForCausalLM +from transformers.models.auto.auto_factory import _BaseAutoModelClass + +from neural_compressor.common import options + +from .load import load + +LWQ_WORKSPACE = os.path.join(options.workspace, "layer_wise_tmp") + + +class QDQLayer(torch.nn.Module): + def __init__(self, module, input_scale=None) -> None: + super().__init__() + self.quant = torch.ao.quantization.QuantStub() + self.module = module + self.dequant = torch.ao.quantization.DeQuantStub() + self.input_scale = input_scale + + def forward(self, X): + if self.input_scale is not None: + X = torch.mul(X, self.input_scale) + X = self.quant(X) + X = self.module(X) + X = self.dequant(X) + return X + + +def get_module(model, key): + """Get module from model by key name. 
+ + Args: + model (torch.nn.Module): original model + key (str): module name to be replaced + """ + attrs = key.split(".") + module = model + for attr in attrs: + try: + attr = int(attr) + module = module[attr] + except: + module = getattr(module, attr) + return module + + +def get_children(model): + """Get all the children of given model.""" + module_list = [] + children = list(model.children()) + if len(children) == 0: + return [model] + for child in children: + module_list += get_children(child) + return module_list + + +def get_named_children(model, pre=[]): + """Get all the name and children of given model.""" + module_list = [] + if len(list(model.children())) == 0: + return [(".".join(pre), model)] + for name, module in model.named_children(): + module_list += get_named_children(module, pre=pre + [name]) + return module_list + + +def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None): + """Download hugging face model from hf hub.""" + from huggingface_hub.constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE + from huggingface_hub.file_download import REGEX_COMMIT_HASH, repo_folder_name + from huggingface_hub.utils import EntryNotFoundError + + if cache_dir is None: + cache_dir = HUGGINGFACE_HUB_CACHE + if revision is None: + revision = DEFAULT_REVISION + if repo_type is None: + repo_type = "model" + storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type)) + commit_hash = None + if REGEX_COMMIT_HASH.match(revision): + commit_hash = revision + else: + ref_path = os.path.join(storage_folder, "refs", revision) + if os.path.exists(ref_path): + with open(ref_path) as f: + commit_hash = f.read() + if storage_folder and commit_hash: + pointer_path = os.path.join(storage_folder, "snapshots", commit_hash) + if os.path.isdir(pointer_path): + return pointer_path + else: # pragma: no cover + from huggingface_hub import snapshot_download + + file_path = snapshot_download(repo_id) + return file_path + + +def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, **kwargs): + """Load a empty model.""" + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: # pragma: no cover + path = pretrained_model_name_or_path + else: + path = dowload_hf_model(pretrained_model_name_or_path) + if cls.__base__ == _BaseAutoModelClass: + config = AutoConfig.from_pretrained(path, **kwargs) + with init_empty_weights(): + model = cls.from_config(config) + else: # pragma: no cover + config = cls.config_class.from_pretrained(path, **kwargs) + with init_empty_weights(): + model = cls(config) + model.tie_weights() + model.eval() + model.path = pretrained_model_name_or_path + return model + + +def get_super_module_by_name(model, module_name): + """Get the father module with given name of child module.""" + name_list = module_name.split(".") + for name in name_list[:-1]: + if hasattr(model, name): + model = getattr(model, name) + else: # pragma: no cover + return None + if hasattr(model, name_list[-1]): + return model + else: # pragma: no cover + return None + + +def update_module(model, module_name, new_module): + """Update module.""" + super_module = get_super_module_by_name(model, module_name) + if super_module: + setattr(super_module, module_name.split(".")[-1], new_module) + + +def load_layer_wise_quantized_model(path): # pragma: no cover + """Load layer wise quantized model.""" + model = torch.load(os.path.join(path, "model_arch.pt")) + for name, _ in model.named_modules(): + if name + ".pt" in os.listdir(path): + 
update_module(model, name, torch.load(os.path.join(path, name + ".pt"))) + model.eval() + return model + + +def load_tensor_from_shard(pretrained_model_name_or_path, tensor_name, prefix=None): # pragma: no cover + """Load tensor from shard.""" + path = _get_path(pretrained_model_name_or_path) + idx_dict = json.load(open(os.path.join(path, "pytorch_model.bin.index.json"), "r"))["weight_map"] + if tensor_name not in idx_dict.keys(): + if tensor_name.replace(f"{prefix}.", "") in idx_dict.keys(): + tensor_name = tensor_name.replace(f"{prefix}.", "") + else: + assert False, "{} not in the index.json".format(tensor_name) + return load_tensor(os.path.join(path, idx_dict[tensor_name]), tensor_name, None) + + +def load_tensor(path, tensor_name=None, prefix=None): + """Load a tensor from bin file with given tensor name.""" + # transformers.modeling_utils + if tensor_name: + if "gamma" in tensor_name: # pragma: no cover + tensor_name = tensor_name.replace("gamma", "weight") + if "beta" in tensor_name: # pragma: no cover + tensor_name = tensor_name.replace("beta", "bias") + + if os.path.isdir(path): + path = os.path.join(path, "pytorch_model.bin") + state_dict = load(path, tensor_name, prefix) + if tensor_name: + if tensor_name in state_dict: + return state_dict[tensor_name] + else: # pragma: no cover + return state_dict[tensor_name.replace(f"{prefix}.", "")] + else: # pragma: no cover + return state_dict + + +def _get_path(pretrained_model_name_or_path): + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: # pragma: no cover + path = pretrained_model_name_or_path + else: + path = dowload_hf_model(pretrained_model_name_or_path) + return path + + +def load_value(model, param_name, path): + if "lm_head" in param_name and getattr(model.config, "tie_word_embeddings", True): + input_embeddings = model.get_input_embeddings() + modules = get_named_children(model) + for name, module in modules: + if module == input_embeddings: + param_name = name + "." + param_name.split(".")[-1] + prefix = model.base_model_prefix + if "pytorch_model.bin.index.json" in os.listdir(path): + value = load_tensor_from_shard(path, param_name, prefix) + else: + value = load_tensor(os.path.join(path, "pytorch_model.bin"), param_name, prefix) + return value + + +def load_module(model, module_name, path, device="cpu"): + module = get_module(model, module_name) + for n, p in module.named_parameters(): + param_name = module_name + "." + n + value = load_value(model, param_name, path) + set_module_tensor_to_device(model, param_name, device, value) + + +def register_weight_hooks(model, path, device="cpu", clean_weight=True, saved_path=None): + if saved_path: + os.makedirs(saved_path, exist_ok=True) + + def forward_pre_hook(name): + def hook(module, input): + state_dict = None + if os.path.exists(os.path.join(LWQ_WORKSPACE, f"{name}.pt")): + state_dict = torch.load(os.path.join(LWQ_WORKSPACE, f"{name}.pt")) + for n, p in module.named_parameters(): + param_name = name + "." 
+ n + if state_dict: + value = state_dict[n] + else: + value = load_value(model, param_name, path) + set_module_tensor_to_device(model, param_name, device, value) + + return hook + + def forward_hook(name): + def hook(module, input, output): + if saved_path: + file_path = os.path.join(saved_path, f"{name}.pt") + torch.save(module.state_dict(), file_path) + clean_module_weight(module) + + return hook + + handle = {} + modules = get_named_children(model) + for name, module in modules: + handle[name] = [module.register_forward_pre_hook(forward_pre_hook(name))] + if clean_weight: + handle[name] += [module.register_forward_hook(forward_hook(name))] + return handle + + +def clean_module_weight(module): + if isinstance(module, QDQLayer): + submodule = module.module + else: + submodule = module + + for n, m in submodule.named_parameters(): + is_buffer = n in submodule._buffers + old_value = getattr(submodule, n) + with torch.no_grad(): + if is_buffer: + submodule._buffers[n] = torch.zeros(old_value.shape, device="meta") + else: + param_cls = type(submodule._parameters[n]) + kwargs = submodule._parameters[n].__dict__ + new_value = torch.zeros(old_value.shape, device="meta") + new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to("meta") + submodule._parameters[n] = new_value + gc.collect() diff --git a/neural_compressor/torch/algorithms/weight_only/__init__.py b/neural_compressor/torch/algorithms/weight_only/__init__.py index ac8feca4f40..032dab931b5 100644 --- a/neural_compressor/torch/algorithms/weight_only/__init__.py +++ b/neural_compressor/torch/algorithms/weight_only/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .utility import * from .rtn import rtn_quantize from .gptq import gptq_quantize +from .modules import WeightOnlyLinear +from .utility import * diff --git a/neural_compressor/torch/algorithms/weight_only/gptq.py b/neural_compressor/torch/algorithms/weight_only/gptq.py index 339acbcec20..cf90b5c7048 100644 --- a/neural_compressor/torch/algorithms/weight_only/gptq.py +++ b/neural_compressor/torch/algorithms/weight_only/gptq.py @@ -15,8 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Copied from neural_compressor/adaptor/torch_utils/gptq.py - import gc import math import random @@ -30,7 +28,9 @@ import transformers from tqdm import tqdm -from neural_compressor.torch.utils import logger +from neural_compressor.torch.utils import fetch_module, logger, set_module + +from .modules import WeightOnlyLinear DEBUG = False @@ -193,12 +193,15 @@ def __init__( self, model, weight_config={}, + dataloader=None, nsamples=128, - dataloader_len=10, use_max_length=True, - pad_max_length=2048, + max_seq_length=2048, device=None, - layer_wise=False, + export_compressed_model=False, + use_layer_wise=False, + model_path="", + run_fn=None, *args, **kwargs, ): @@ -218,6 +221,10 @@ def __init__( ... } dataloader: an iterable containing calibration datasets, contains (inputs, targets) + export_compressed_model (bool, optional): Choose return fp32 or int32 model. Defaults to False. + use_layer_wise (bool): Enables quantize model per layer. Defaults to False. + model_path (str): Model path that is used to load state_dict per layer. + run_fn: a function to run model inference for collecting input information. 
device: cpu or cuda """ # model @@ -230,16 +237,18 @@ def __init__( # weight config self.weight_config = weight_config # default settings, check configs - self.wdtype_default = "int" - self.wbits_default = 4 + self.dtype_default = "int" + self.bits_default = 4 self.group_size_default = 128 self.block_size_default = 128 self.percdamp_default = 0.01 self.sym_default = False self.act_order_default = False + self.static_groups_default = False self.perchannel_default = True self.mse_default = False - self.double_quant_dtype_default = "fp32" + self.use_double_quant_default = False + self.double_quant_dtype_default = "int" self.double_quant_bits_default = 4 self.double_quant_group_size_default = 128 self.double_quant_sym_default = False @@ -251,31 +260,148 @@ def __init__( self.device = self.model.device self.is_ready = False - self.layer_wise = layer_wise + self.export_compressed_model = export_compressed_model + self.use_layer_wise = use_layer_wise + self.model_path = model_path # dataloader self.use_max_length = use_max_length - self.pad_max_length = pad_max_length - self.dataloader_original = None + self.max_seq_length = max_seq_length + self.dataloader_original = dataloader self.dataloader = [] - self.dataloader_len = dataloader_len self.nsamples = nsamples - self.args = args - self.kwargs = kwargs - self.run_fn = self.kwargs.get("run_fn", None) - self.run_args = self.kwargs.get("run_args", None) - self.dataloader_len = dataloader_len - # compare 2.x, use run_fn to calibration - # self.prepare_dataloader() - self._post_init() - - def _post_init(self): - self.cache_key_arguments = { - "i": 0 - } # a dict of list, keyword arguments ("attention_masks", "position_ids", etc.) - # Note that the first elements in cache_positional_arguments is main input: hidden_states - self.cache_positional_arguments = [] # a list of list, positional arguments ("rotary_pos_emb" in chatglm) - self.is_ready = True + self.run_fn = run_fn + self.run_args = kwargs.get("run_args", None) + if run_fn is None: + self.prepare_dataloader() + + def prepare_dataloader(self): + if self.use_max_length: + # (Recommend) only take sequence whose length exceeds self.max_seq_length, + # which preserves calibration's tokens are all valid + # This is GPTQ official dataloader implementation + self.obtain_first_n_samples_fulllength() + else: + # general selection, no padding, not GPTQ original implementation. + self.obtain_first_n_samples() + + def obtain_first_n_samples(self, seed=0): + """Get first nsample data as the real calibration dataset.""" + self.dataloader.clear() + random.seed(seed) + for batch in self.dataloader_original: + # process data, depends on its data type. 
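# (three batch layouts are handled below: list/tuple batches, dict batches keyed
#  by "input_ids", and plain tensors; any sequence longer than self.max_seq_length
#  is randomly cropped to a window of exactly self.max_seq_length tokens)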
+ if len(self.dataloader) == self.nsamples: + logger.info(f"Successfully collect {self.nsamples} calibration samples.") + break + # list, tuple + if isinstance(batch, list) or isinstance(batch, tuple): + if batch[0].shape[-1] > self.max_seq_length: + i = random.randint(0, batch[0].shape[-1] - self.max_seq_length - 1) + j = i + self.max_seq_length + batch_final = [] + for item in batch: + if isinstance(item, torch.Tensor) and item.shape.__len__() == 2: + batch_final.append(item[:, i:j]) + else: + batch_final.append(item) + else: + batch_final = batch[:] + # dict + elif isinstance(batch, dict): + try: + length = batch["input_ids"].shape[-1] + except: + logger.warning("Please make sure your dict'like data contains key of 'input_ids'.") + continue + batch_final = {} + if length > self.max_seq_length: + i = random.randint(0, length - self.max_seq_length - 1) + j = i + self.max_seq_length + # may have to slice every sequence related data + for key in batch.keys(): + if isinstance(batch[key], torch.Tensor): + batch_final[key] = batch[key][:, i:j] # slice on sequence length dim + else: + batch_final[key] = batch[key] + else: + batch_final = batch + # tensor + else: + if batch.shape[-1] > self.max_seq_length: + i = random.randint(0, batch.shape[-1] - self.max_seq_length - 1) + j = i + self.max_seq_length + batch_final = batch[:, i:j] + else: + batch_final = batch + self.dataloader.append(batch_final) + + if len(self.dataloader) < self.nsamples: + logger.warning(f"Try to use {self.nsamples} data, but entire dataset size is {len(self.dataloader)}.") + + def obtain_first_n_samples_fulllength(self, seed=0): + self.dataloader.clear() + random.seed(seed) + unified_length = self.max_seq_length + for batch in self.dataloader_original: + if len(self.dataloader) == self.nsamples: + logger.info(f"Successfully collect {self.nsamples} calibration samples.") + break + # list & tuple, gpt-j-6b mlperf, etc. 
+ if isinstance(batch, list) or isinstance(batch, tuple): + if batch[0].shape[-1] == unified_length: + batch_final = batch[:] + elif batch[0].shape[-1] > unified_length: + i = random.randint(0, batch[0].shape[-1] - unified_length - 1) + j = i + unified_length + batch_final = [] + for item in batch: + if isinstance(item, torch.Tensor) and item.shape.__len__() == 2: + batch_final.append(item[:, i:j]) + else: + batch_final.append(item) + else: + # not match max length, not include in target dataset + continue + # dict + elif isinstance(batch, dict): + try: + length = batch["input_ids"].shape[-1] + except: + logger.warning("Please make sure your dict'like data contains key of 'input_ids'.") + continue + batch_final = {} + if length == self.max_seq_length: + batch_final = batch + elif length > self.max_seq_length: + i = random.randint(0, length - self.max_seq_length - 1) + j = i + self.max_seq_length + # may have to slice every sequence related data + for key in batch.keys(): + if isinstance(batch[key], torch.Tensor): + batch_final[key] = batch[key][:, i:j] # slice on sequence length dim with same position + else: + batch_final[key] = batch[key] + else: + # not match max length, not include in target dataset + continue + # tensor + else: + if batch.shape[-1] == unified_length: + batch_final = batch + elif batch.shape[-1] > unified_length: + i = random.randint(0, batch.shape[-1] - unified_length - 1) + j = i + unified_length + batch_final = batch[:, i:j] + else: + # not match max length, not include in target dataset + continue + self.dataloader.append(batch_final) + if len(self.dataloader) < self.nsamples: # pragma: no cover + logger.warning( + f"Trying to allocate {self.nsamples} data with fixed length {unified_length}, \ + but only {len(self.dataloader)} samples are found. Please use smaller 'self.max_seq_length' value." 
+ ) def get_full_layer_name(self, sub_layer_name, block_idx): transformer_name = self.gptq_related_blocks["transformers_name"] @@ -283,46 +409,45 @@ def get_full_layer_name(self, sub_layer_name, block_idx): def check_layer_config(self): """Copy arguments from weight_config to built-in attributes.""" - if "wbits" in self.weight_config: - tmp_weight_config = {} + if not self.weight_config: for name, module in self.model.named_modules(): - tmp_weight_config[name] = {} - tmp_weight_config[name]["wdtype"] = self.weight_config.get("wdtype", self.wdtype_default) - tmp_weight_config[name]["wbits"] = self.weight_config.get("wbits", self.wbits_default) - tmp_weight_config[name]["group_size"] = self.weight_config.get("group_size", self.group_size_default) - tmp_weight_config[name]["block_size"] = self.weight_config.get("block_size", self.group_size_default) - tmp_weight_config[name]["percdamp"] = self.weight_config.get("pecdamp", self.percdamp_default) - tmp_weight_config[name]["sym"] = self.weight_config.get("sym", self.sym_default) - tmp_weight_config[name]["act_order"] = self.weight_config.get("act_order", self.act_order_default) - tmp_weight_config[name]["perchannel"] = self.weight_config.get("perchannel", self.perchannel_default) - tmp_weight_config[name]["mse"] = self.weight_config.get("mse", self.mse_default) - tmp_weight_config[name]["double_quant_dtype"] = self.weight_config.get( - "double_quant_dtype", self.double_quant_dtype_default - ) - tmp_weight_config[name]["double_quant_bits"] = self.weight_config.get( - "double_quant_bits", self.double_quant_bits_default - ) - tmp_weight_config[name]["double_quant_group_size"] = self.weight_config.get( - "double_quant_group_size", self.double_quant_group_size_default - ) - tmp_weight_config[name]["double_quant_sym"] = self.weight_config.get( - "double_quant_sym", self.double_quant_sym_default - ) - self.weight_config = tmp_weight_config + self.weight_config[name] = { + "dtype": self.dtype_default, + "bits": self.bits_default, + "sym": self.sym_default, + "group_size": self.group_size_default, + "mse": self.mse_default, + "perchannel": self.perchannel_default, + "act_order": self.act_order_default, + "percdamp": self.percdamp_default, + "block_size": self.block_size_default, + "static_groups": self.static_groups_default, + "use_double_quant": self.use_double_quant_default, + "double_quant_dtype": self.double_quant_dtype_default, + "double_quant_bits": self.double_quant_bits_default, + "double_quant_sym": self.double_quant_sym_default, + "double_quant_group_size": self.double_quant_group_size_default, + } else: for layer_name, config in self.weight_config.items(): - self.weight_config[layer_name]["wdtype"] = config.get("wdtype", self.wdtype_default) - self.weight_config[layer_name]["wbits"] = config.get("wbits", self.wbits_default) + self.weight_config[layer_name]["dtype"] = config.get("dtype", self.dtype_default) + self.weight_config[layer_name]["bits"] = config.get("bits", self.bits_default) self.weight_config[layer_name]["group_size"] = config.get("group_size", self.group_size_default) self.weight_config[layer_name]["block_size"] = config.get("block_size", self.group_size_default) - self.weight_config[layer_name]["percdamp"] = config.get("pecdamp", self.percdamp_default) + self.weight_config[layer_name]["percdamp"] = config.get("percdamp", self.percdamp_default) self.weight_config[layer_name]["sym"] = config.get("sym", self.sym_default) self.weight_config[layer_name]["act_order"] = config.get("act_order", self.act_order_default) + 
self.weight_config[layer_name]["static_groups"] = config.get( + "static_groups", self.static_groups_default + ) self.weight_config[layer_name]["perchannel"] = config.get("perchannel", self.perchannel_default) self.weight_config[layer_name]["mse"] = config.get("mse", self.mse_default) + self.weight_config[layer_name]["use_double_quant"] = config.get( + "use_double_quant", self.use_double_quant_default + ) self.weight_config[layer_name]["double_quant_dtype"] = config.get( "double_quant_dtype", self.double_quant_dtype_default - ) + ) # only support int self.weight_config[layer_name]["double_quant_bits"] = config.get( "double_quant_bits", self.double_quant_bits_default ) @@ -332,6 +457,12 @@ def check_layer_config(self): self.weight_config[layer_name]["double_quant_sym"] = config.get( "double_quant_sym", self.double_quant_sym_default ) + if ( + self.weight_config[layer_name]["dtype"] != "int" + and "int" in self.weight_config[layer_name]["dtype"] + ): + self.weight_config[layer_name]["bits"] = int(self.weight_config[layer_name]["dtype"].lstrip("int")) + self.weight_config[layer_name]["dtype"] = "int" def get_layer_config(self, layer_name): """Obtain config for one layer, since GPTQ supports layer-wise config.""" @@ -359,11 +490,21 @@ def track_hidden_states(self, data): @torch.no_grad() def pre_quantization(self): """Prepare input calibration data and other attributes which are critical for gptq execution.""" + try: + self.cache_key_arguments = { + "batch_num": 0 + } # a dict of list, keyword arguments ("attention_masks", "position_ids", etc.) + # Note that the first elements in cache_positional_arguments is main input: hidden_states + self.cache_positional_arguments = [] # a list of list, positional arguments ("rotary_pos_emb" in chatglm) + self.is_ready = True + except: + logger.warning("GPTQ Quantizer initialization failed!") + pass # critical: hooker function which collects inputs def forward(layer, *args, **kwargs): # inputs[inputs_info['idx']] = input_ids # TODO solve the problem of batchsize!=1 - self.cache_key_arguments["i"] += 1 + self.cache_key_arguments["batch_num"] += 1 for arg in kwargs: # TODO: investigate include parameters # each outputs can be different shape, hence also use list to store @@ -383,12 +524,12 @@ def forward(layer, *args, **kwargs): raise ValueError # Step1: fetch the embeddings and other layers before the transformer stack. 
- if not self.layer_wise: + if not self.use_layer_wise: for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items(): embedding_layer = embedding_layer.to(self.device) # Step2: modify the first transformer block's forward function to obtain inputs for calibration - if not self.layer_wise: + if not self.use_layer_wise: self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device) forward_cache = self.gptq_related_blocks["transformers"][0].forward self.gptq_related_blocks["transformers"][0].forward = partial( @@ -398,23 +539,24 @@ def forward(layer, *args, **kwargs): # Step3: run forward to obtain calibration datasets logger.info("Collecting calibration inputs...") logger.info("Collecting calibration inputs by running the run_fn provided by user.") - if self.run_args: - self.run_fn(self.model, self.run_args) + if self.run_fn: + if self.run_args: + self.run_fn(self.model, *self.run_args) + else: + self.run_fn(self.model) else: - self.run_fn(self.model) - - # for batch in tqdm(self.dataloader): - # if not self.layer_wise: - # batch = move_input_to_device(batch, self.device) - # try: - # if isinstance(batch, tuple) or isinstance(batch, list): - # self.model(batch[0]) - # elif isinstance(batch, dict): - # self.model(**batch) - # else: - # self.model(batch) - # except ValueError: - # pass + for batch in tqdm(self.dataloader): + if not self.use_layer_wise: + batch = move_input_to_device(batch, self.device) + try: + if isinstance(batch, tuple) or isinstance(batch, list): + self.model(batch[0]) + elif isinstance(batch, dict): + self.model(**batch) + else: + self.model(batch) + except ValueError: + pass # output inp data shape logger.info("All calibration data's shape =>") # check all hidden_states shape @@ -427,7 +569,7 @@ def forward(layer, *args, **kwargs): # Step 4: restore original forward function, relocate layers back to cpu. self.gptq_related_blocks["transformers"][0].forward = forward_cache - if not self.layer_wise: + if not self.use_layer_wise: self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu() for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items(): embedding_layer.to(self.device) @@ -456,20 +598,24 @@ def update_blockwise_hidden_states(self, outs): self.cache_positional_arguments[0] = outs[:] @torch.no_grad() - def execute_quantization(self, means=None, stds=None, model_path=None): + def execute_quantization(self, means=None, stds=None): """Run quantization.""" # Step1: prepare quantization (calibration datasets) logger.info("Begin ====>") self.pre_quantization() + model_path = self.model_path # Step2: run gptq quantization in a transformer block-wise manner. gptq_config = {} tblock_length = len(self.gptq_related_blocks["transformers"]) for block_idx in range(tblock_length): logger.info(f"Quantizing layer {block_idx + 1} / {tblock_length}..") - # if we do not apply layer-wise feature, we still place the entire block on the GPU - transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device) + if not self.use_layer_wise: + # if we do not apply layer-wise feature, we still place the entire block on the GPU + transformer_block = self.gptq_related_blocks["transformers"][block_idx].to(self.device) + else: + transformer_block = self.gptq_related_blocks["transformers"][block_idx] # .to(self.device) # Step2.1: obtain all layers (Linear, Conv2d, etc) in the block which can be quantized. 
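# find_layers returns a {layer_name: module} mapping of the block's quantizable
# sub-layers; layer_name is later expanded to a full model path via
# get_full_layer_name(layer_name, block_idx) when looking up per-layer configs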
sub_layers = find_layers(transformer_block) sub_layers_to_quant = {} @@ -492,7 +638,13 @@ def execute_quantization(self, means=None, stds=None, model_path=None): # ) full_layer_name = self.get_full_layer_name(layer_name, block_idx) weight_config_this_layer = self.get_layer_config(full_layer_name) - W = sub_layers[layer_name].weight.data.clone() + if self.use_layer_wise: + from neural_compressor.torch.algorithms.layer_wise import load_value + + W = load_value(self.model, full_layer_name + ".weight", model_path) + else: + W = sub_layers[layer_name].weight.data.clone() + gptq_for_this_block[layer_name] = GPTQ(sub_layers[layer_name], W, self.device) # gptq_for_this_block[layer_name].quantizer = Quantizer() gptq_for_this_block[layer_name].quantizer.configure(weight_config_this_layer) @@ -507,13 +659,13 @@ def tmp(_, inp, out): handles = [] # register handles which add inputs and outputs to gptq object for layer_name in sub_layers: handles.append(sub_layers[layer_name].register_forward_hook(add_batch(layer_name))) - idx = self.cache_key_arguments.pop("i") - for j in range(self.dataloader_len): + batch_num = self.cache_key_arguments.pop("batch_num") + for j in range(batch_num): cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j) cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j) out = transformer_block(*cache_positional_batch, **cache_keyword_batch) out = self.track_hidden_states(out) - self.cache_key_arguments["i"] = idx + self.cache_key_arguments["batch_num"] = batch_num for h in handles: h.remove() # Step 2.4: everything is prepared, so start quantization! @@ -523,15 +675,81 @@ def tmp(_, inp, out): # ) weight_config_this_layer = self.get_layer_config(self.get_full_layer_name(layer_name, block_idx)) logger.info(f"Quantizing layer {layer_name}") - W = sub_layers[layer_name].weight.data.clone() + if self.use_layer_wise: + from neural_compressor.torch.algorithms.layer_wise import load_value + + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + W = load_value(self.model, full_layer_name + ".weight", model_path) + else: + W = sub_layers[layer_name].weight.data.clone() scale, zp, Q = gptq_for_this_block[layer_name].fasterquant( W, blocksize=weight_config_this_layer["block_size"], percdamp=weight_config_this_layer["percdamp"], groupsize=weight_config_this_layer["group_size"], act_order=weight_config_this_layer["act_order"], + static_groups=weight_config_this_layer["static_groups"], ) - sub_layers[layer_name].weight.data = Q + if self.export_compressed_model: + m = fetch_module(transformer_block, layer_name) + gptq_scale = scale + gptq_zp = None if weight_config_this_layer["sym"] else torch.tensor(zp, dtype=torch.int32) + # recover INT weight + gptq_perm = gptq_for_this_block[layer_name].perm if weight_config_this_layer["act_order"] else None + if weight_config_this_layer["act_order"]: + Q.copy_(Q[:, gptq_perm]) + from .utility import quant_weight_w_scale + + quant_weight_w_scale( + Q, + gptq_scale, + gptq_zp, + weight_config_this_layer["group_size"], + dtype=weight_config_this_layer["dtype"], + ) + # import pdb;pdb.set_trace() + if weight_config_this_layer["act_order"]: + invperm = torch.argsort(gptq_perm) + Q.copy_(Q[:, invperm]) + int_weight = Q.type(torch.int32) # copy_ is not workable for different types. 
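# at this point Q holds the integer codes produced by quant_weight_w_scale and,
# if act_order was used, has been permuted back to the original column order, so
# int_weight/scale/zp match the layout expected by WeightOnlyLinear.pack below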
+ # replace module + new_module = WeightOnlyLinear( + m.in_features, + m.out_features, + dtype=weight_config_this_layer["dtype"], + bits=weight_config_this_layer["bits"], + group_size=weight_config_this_layer["group_size"], + zp=gptq_zp is not None, + bias=m.bias is not None, + g_idx=gptq_perm is not None, + device=self.device, + ) + new_module.pack(int_weight, gptq_scale, gptq_zp, m.bias) + set_module(transformer_block, layer_name, new_module) + if self.use_layer_wise: + from neural_compressor.torch.algorithms.layer_wise import ( + LWQ_WORKSPACE, + clean_module_weight, + load_value, + set_module_tensor_to_device, + ) + + sub_layer = sub_layers[layer_name] + full_layer_name = self.get_full_layer_name(layer_name, block_idx) + for n, p in sub_layer.named_parameters(): + param_name = full_layer_name + "." + n + if n == "weight": + set_module_tensor_to_device(self.model, param_name, self.device, Q) + else: + value = load_value(self.model, param_name, model_path) + set_module_tensor_to_device(self.model, param_name, self.device, value) + # sub_layer.weight.data = Q + torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt") + clean_module_weight(sub_layer) + del Q + gc.collect() + else: + sub_layers[layer_name].weight.data = Q gptq_config[self.get_full_layer_name(layer_name, block_idx)] = {"scale": scale} if not weight_config_this_layer["sym"]: gptq_config[self.get_full_layer_name(layer_name, block_idx)]["zero"] = zp @@ -543,15 +761,18 @@ def tmp(_, inp, out): # Step 2.5: replace output data with quantized weights outs = [] - idx = self.cache_key_arguments.pop("i") - for j in range(self.dataloader_len): + batch_num = self.cache_key_arguments.pop("batch_num") + for j in range(batch_num): cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j) cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j) out = transformer_block(*cache_positional_batch, **cache_keyword_batch) out = self.track_hidden_states(out) outs.append(out) - self.cache_key_arguments["i"] = idx - self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu() + self.cache_key_arguments["batch_num"] = batch_num + if self.use_layer_wise: + self.gptq_related_blocks["transformers"][block_idx] = transformer_block + else: + self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu() del gptq_for_this_block torch.cuda.empty_cache() # iteratively replace the input with output, thus layerwise quantization can continue. @@ -616,9 +837,9 @@ def add_batch(self, inp, out): # inp = inp.float() inp = math.sqrt(2 / self.nsamples) * inp.float() # self.H += 2 / self.nsamples * inp.matmul(inp.t()) - self.H += inp.matmul(inp.t()) # H = X*X, which should be a sysm matrix + self.H += inp.matmul(inp.t()) # H = X*X, which should be a sym matrix - def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False): + def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, static_groups=False): # W = self.layer.weight.data.clone() weight_shape, weight_dtype = W.shape, W.data.dtype if isinstance(self.layer, nn.Conv2d): @@ -638,6 +859,17 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F H[dead, dead] = 1 W[:, dead] = 0 # such channel makes no contribution to quantization computation + # enable static_groups + # calculate the quantization parameters for original group in advance. 
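# (with act_order enabled the columns are visited in a permuted order; computing
#  one quantizer per original group up front keeps scales/zero-points tied to the
#  stored group layout rather than to the permuted processing order)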
+ if static_groups: + import copy + + groups = [] + for i in range(0, self.columns, groupsize): + quantizer = copy.deepcopy(self.quantizer) + quantizer.find_params(W[:, i : (i + groupsize)], weight=True) + groups.append(quantizer) + # rearrange considering the diag's value if act_order: perm = torch.argsort(torch.diag(H), descending=True) @@ -674,11 +906,16 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F d = Hinv1[i, i] if groupsize != -1: - if (i1 + i) % groupsize == 0: - self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + groupsize)], weight=True) - scale.append(self.quantizer.scale) - zero.append(self.quantizer.zero) - + if not static_groups: + if (i1 + i) % groupsize == 0: + self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + groupsize)], weight=True) + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + else: + idx = i1 + i + if act_order: + idx = perm[idx] + self.quantizer = groups[idx // groupsize] q = self.quantizer.quantize( w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq ).flatten() @@ -724,6 +961,9 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F return scale, zero, Q def free(self): + if DEBUG: + self.inp1 = None + self.out1 = None self.H = None self.Losses = None self.Trace = None @@ -740,9 +980,8 @@ def __init__(self, shape=1): def configure(self, weight_config_this_layer, norm=2.4, grid=100, maxshrink=0.8, trits=False): for k, v in weight_config_this_layer.items(): setattr(self, k, v) - self.maxq = torch.tensor(2**self.wbits - 1) + self.maxq = torch.tensor(2**self.bits - 1) self.scheme = "sym" if self.sym else "asym" - self.double_quant = self.double_quant_dtype != "fp32" self.double_quant_scheme = "sym" if self.double_quant_sym else "asym" self.norm = norm self.grid = grid @@ -754,21 +993,20 @@ def find_params(self, x, weight=False): dev = x.device self.maxq = self.maxq.to(dev) # NF4 FP4 - if self.wdtype != "int": + if self.dtype != "int": from .utility import quant_tensor tmp = x.clone() # tmp will be replaced after quant_tensor - _, scale, zero = quant_tensor( tmp, - dtype=self.wdtype, - bits=self.wbits, + dtype=self.dtype, + bits=self.bits, group_size=self.group_size, scheme=self.scheme, quantile=1.0, return_int=True, full_range=False, - double_quant=self.double_quant, + double_quant=self.use_double_quant, double_quant_dtype=self.double_quant_dtype, double_quant_bits=self.double_quant_bits, double_quant_scheme=self.double_quant_scheme, @@ -848,7 +1086,8 @@ def find_params(self, x, weight=False): self.scale = self.scale.reshape(shape) self.zero = self.zero.reshape(shape) - if self.double_quant: + if self.use_double_quant: + # for INT from .utility import quant_tensor orig_scale_shape = self.scale.shape @@ -877,11 +1116,11 @@ def find_params(self, x, weight=False): def quantize(self, x, scale, zero, maxq): """Do quantization.""" - if self.wdtype != "int": + if self.dtype != "int": from .utility import quantize_4bit tmp = x.clone() # tmp will be replaced after quant_tensor - return quantize_4bit(tmp, dtype=self.wdtype, scale=scale) + return quantize_4bit(tmp, dtype=self.dtype, scale=scale) else: if maxq < 0: return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero @@ -892,226 +1131,41 @@ def ready(self): return torch.all(self.scale != 0) -# TODO (Yi) remove it after unifying the algo config parser -def gptq_config_mapping(configs_mapping): - # convert GPTQ_CONFIG to gptq_quantize's weight config - # convert tune_cfg to gptq_quantize's 
weight config - # for layer_wise quant mode - # TODO (Yi) uncomment it when port layer-wise - # if recipe_cfgs.get("layer_wise_quant", False): - # layer_wise = True - # from .torch_utils.layer_wise_quant.utils import LWQ_WORKSPACE, _get_path, register_weight_hooks - - # os.makedirs(LWQ_WORKSPACE, exist_ok=True) - # # model_path = recipe_cfgs["layer_wise_quant_args"].get("model_path", None) - # model_path = model.path - # assert model_path, "model_path should not be None." - # model_path = _get_path(model_path) - # lwq_handles = register_weight_hooks( - # model, model_path, device=self.device, clean_weight=True, saved_path=LWQ_WORKSPACE - # ) - - weight_config = {} - for (op_name, op_type), op_config in configs_mapping.items(): - if op_config.weight_dtype == "fp32": - continue - else: - weight_config[op_name] = { - "wdtype": op_config.weight_dtype, - "wbits": op_config.weight_bits, - "group_size": op_config.weight_group_size, - "sym": op_config.weight_sym, - "percdamp": op_config.percdamp, - "act_order": op_config.act_order, - "block_size": op_config.block_size, - "mse": op_config.enable_mse_search, - "double_quant_dtype": op_config.double_quant_dtype, - "double_quant_bits": op_config.double_quant_bits, - "double_quant_group_size": op_config.double_quant_group_size, - "double_quant_sym": op_config.double_quant_sym, - } - nsamples = op_config.nsamples - dataloader_len = op_config.dataloader_len - use_max_length = op_config.use_max_length - pad_max_length = op_config.pad_max_length - device = op_config.device - - if use_max_length and op_config.pad_max_length == 2048: - logger.warning( - "You choose to use unified sequence length for calibration, \ - but you have not set length value. Default sequence length is 2048 and this might cause inference error!" 
- ) - - return weight_config, nsamples, use_max_length, pad_max_length, device, dataloader_len - - -def gptq_quantize(model, configs_mapping, *args, **kwargs): - """Apply gptq.""" +def gptq_quantize( + model, + weight_config={}, + dataloader=None, + nsamples=128, + max_seq_length=2048, + use_max_length=True, + device=None, + export_compressed_model=False, + use_layer_wise=False, + model_path=None, + run_fn=None, + run_args=None, +): + """Run weight-only quantization with.""" # TODO: unify weight_config keys, add docstring, and support default config - weight_config, nsamples, use_max_length, pad_max_length, device, dataloader_len = gptq_config_mapping( - configs_mapping - ) assert isinstance(model, torch.nn.Module), "only support torch module" - # TODO (Yi) disable layer-wise and model_path first - layer_wise = False - model_path = None + if use_layer_wise: + assert model_path is not None, "model_path should not be None when use layer wise mode" + from .gptq import GPTQuantizer gptq_quantizer = GPTQuantizer( model, weight_config, + dataloader, nsamples, - dataloader_len, use_max_length, - pad_max_length, + max_seq_length, device, - layer_wise=layer_wise, - *args, - **kwargs, + export_compressed_model=export_compressed_model, + use_layer_wise=use_layer_wise, + model_path=model_path, + run_fn=run_fn, + run_args=run_args, ) - fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization(model_path=model_path) - logger.info("GPTQ quantization done.") + fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization() + logger.info("GPTQ quantizing done.") return fp32_modified_model, gptq_config - - -class DataloaderPreprocessor: - def __init__(self, dataloader_original, use_max_length=False, pad_max_length=2048, nsamples=128) -> None: - self.dataloader_original = dataloader_original - self.use_max_length = use_max_length - self.pad_max_length = pad_max_length - self.nsamples = nsamples - self.dataloader = [] - self.is_ready = False - - def get_prepared_dataloader(self): - if not self.is_ready: - self.prepare_dataloader() - return self.dataloader - - def prepare_dataloader(self): - if self.use_max_length: - # (Recommend) only take sequence whose length exceeds self.pad_max_length, - # which preserves calibration's tokens are all valid - # This is GPTQ official dataloader implementation - self.obtain_first_n_samples_fulllength() - else: - # general selection, no padding, not GPTQ original implementation. - self.obtain_first_n_samples() - self.is_ready = True - - def obtain_first_n_samples(self, seed=0): - """Get first nsample data as the real calibration dataset.""" - self.dataloader.clear() - random.seed(seed) - for batch in self.dataloader_original: - # process data, depends on its data type. 
- if len(self.dataloader) == self.nsamples: - logger.info(f"Successfully collect {self.nsamples} calibration samples.") - break - # list, tuple - if isinstance(batch, list) or isinstance(batch, tuple): - if batch[0].shape[-1] > self.pad_max_length: - i = random.randint(0, batch[0].shape[-1] - self.pad_max_length - 1) - j = i + self.pad_max_length - batch_final = [] - for item in batch: - if isinstance(item, torch.Tensor) and item.shape.__len__() == 2: - batch_final.append(item[:, i:j]) - else: - batch_final.append(item) - else: - batch_final = batch[:] - # dict - elif isinstance(batch, dict): - try: - length = batch["input_ids"].shape[-1] - except: - logger.warning("Please make sure your dict'like data contains key of 'input_ids'.") - continue - batch_final = {} - if length > self.pad_max_length: - i = random.randint(0, length - self.pad_max_length - 1) - j = i + self.pad_max_length - # may have to slice every sequence related data - for key in batch.keys(): - if isinstance(batch[key], torch.Tensor): - batch_final[key] = batch[key][:, i:j] # slice on sequence length dim - else: - batch_final[key] = batch[key] - else: - batch_final = batch - # tensor - else: - if batch.shape[-1] > self.pad_max_length: - i = random.randint(0, batch.shape[-1] - self.pad_max_length - 1) - j = i + self.pad_max_length - batch_final = batch[:, i:j] - else: - batch_final = batch - self.dataloader.append(batch_final) - - if len(self.dataloader) < self.nsamples: - logger.warning(f"Try to use {self.nsamples} data, but entire dataset size is {len(self.dataloader)}.") - - def obtain_first_n_samples_fulllength(self, seed=0): - self.dataloader.clear() - random.seed(seed) - unified_length = self.pad_max_length - for batch in self.dataloader_original: - if len(self.dataloader) == self.nsamples: - logger.info(f"Successfully collect {self.nsamples} calibration samples.") - break - # list & tuple, gpt-j-6b mlperf, etc. 
- if isinstance(batch, list) or isinstance(batch, tuple): - if batch[0].shape[-1] == unified_length: - batch_final = batch[:] - elif batch[0].shape[-1] > unified_length: - i = random.randint(0, batch[0].shape[-1] - unified_length - 1) - j = i + unified_length - batch_final = [] - for item in batch: - if isinstance(item, torch.Tensor) and item.shape.__len__() == 2: - batch_final.append(item[:, i:j]) - else: - batch_final.append(item) - else: - # not match max length, not include in target dataset - continue - # dict - elif isinstance(batch, dict): - try: - length = batch["input_ids"].shape[-1] - except: - logger.warning("Please make sure your dict'like data contains key of 'input_ids'.") - continue - batch_final = {} - if length == self.pad_max_length: - batch_final = batch - elif length > self.pad_max_length: - i = random.randint(0, length - self.pad_max_length - 1) - j = i + self.pad_max_length - # may have to slice every sequence related data - for key in batch.keys(): - if isinstance(batch[key], torch.Tensor): - batch_final[key] = batch[key][:, i:j] # slice on sequence length dim with same position - else: - batch_final[key] = batch[key] - else: - # not match max length, not include in target dataset - continue - # tensor - else: - if batch.shape[-1] == unified_length: - batch_final = batch - elif batch.shape[-1] > unified_length: - i = random.randint(0, batch.shape[-1] - unified_length - 1) - j = i + unified_length - batch_final = batch[:, i:j] - else: - # not match max length, not include in target dataset - continue - self.dataloader.append(batch_final) - if len(self.dataloader) < self.nsamples: # pragma: no cover - logger.warning( - f"Trying to allocate {self.nsamples} data with fixed length {unified_length}, \ - but only {len(self.dataloader)} samples are found. Please use smaller 'self.pad_max_length' value." - ) diff --git a/neural_compressor/torch/algorithms/weight_only/modules.py b/neural_compressor/torch/algorithms/weight_only/modules.py new file mode 100644 index 00000000000..2fb061821c8 --- /dev/null +++ b/neural_compressor/torch/algorithms/weight_only/modules.py @@ -0,0 +1,444 @@ +# +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Torch.nn.Module Class Definition.""" +# Note: Do not import this file unless you have already imported torch, +# since the model classes inherit torch.nn.Module. 
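A rough usage sketch for the WeightOnlyLinear container defined in this file; this is an editorial illustration rather than part of the patch, and it assumes the default use_optimum_format=True layout with integer weight codes plus one scale/zero-point pair per output row and per group:

import math
import torch
from neural_compressor.torch.algorithms.weight_only.modules import WeightOnlyLinear

out_features, in_features, bits, group_size = 8, 32, 4, 16
n_groups = math.ceil(in_features / group_size)

# 4-bit integer codes and per-(row, group) quantization parameters
int_weight = torch.randint(0, 2**bits, (out_features, in_features), dtype=torch.int32)
scale = torch.rand(out_features, n_groups) + 1e-2
zp = torch.full((out_features, n_groups), 2 ** (bits - 1), dtype=torch.int32)

layer = WeightOnlyLinear(in_features, out_features, dtype="int", bits=bits,
                         group_size=group_size, zp=True, bias=False)
layer.pack(int_weight, scale, zp, bias=None)  # bit-packs codes into qweight/qzeros buffers
out = layer(torch.randn(2, in_features))      # forward() recovers fp weights, then calls F.linear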
+import math + +import torch +from torch.autograd import Function +from torch.nn import functional as F + +from neural_compressor.torch.utils import logger + +from .utility import quant_tensor + + +class QDQLayer(torch.nn.Module): + def __init__(self, module, input_scale=None) -> None: + super().__init__() + self.quant = torch.ao.quantization.QuantStub() + self.module = module + self.dequant = torch.ao.quantization.DeQuantStub() + self.input_scale = input_scale + + def forward(self, X): + if self.input_scale is not None: + X = torch.mul(X, self.input_scale) + X = self.quant(X) + X = self.module(X) + X = self.dequant(X) + return X + + +class WeightOnlyLinear(torch.nn.Module): + def __init__( + self, + in_features, + out_features, + dtype="int", + bits=4, + group_size=32, + zp=False, + bias=False, + scale_dtype=torch.float32, + compression_dtype=torch.int32, + compression_dim=1, + g_idx=False, + device="cpu", + use_optimum_format=True, + ): + super().__init__() + self.use_optimum_format = use_optimum_format + self.dtype = dtype + if self.dtype != "int" and "int" in self.dtype: # for nf4, fp4 + bits = self.dtype.lstrip("int") + self.dtype = "int" + if "int" not in self.dtype: # for nf4, fp4 + from neural_compressor.torch.algorithms.weight_only import FLOAT_MAPPING, INT_MAPPING + + self.use_optimum_format = False # optimum_format doesn't suit for symmetric nf4 fp4. + float_list = FLOAT_MAPPING[self.dtype] + int_list = INT_MAPPING[self.dtype] + self.int2float_mapping = {} + for k, v in zip(int_list, float_list): + self.int2float_mapping[k] = v + self.bits = bits + self.device = device + self.in_features = in_features + self.out_features = out_features + self.group_size = group_size if group_size != -1 else in_features + self.compression_dim = compression_dim + assert compression_dtype in [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + ], "Only support torch.int8|16|32|64 as compressed dtype." + dtype_bits_mapping = {torch.int8: 8, torch.int16: 16, torch.int32: 32, torch.int64: 64} + self.compress_bits = dtype_bits_mapping[compression_dtype] + self.n_pack = self.compress_bits // self.bits + # K is input channel, N is output channel + assert compression_dim in [0, 1], ( + "Only support 0 or 1 as compression dimension, " + "0 is output channel, 1 is input channel." 
+ ) + if self.use_optimum_format: + self.float_type = torch.float16 + self.compression_dtype = torch.int32 + self.register_buffer( + "scales", + torch.zeros( + (math.ceil(in_features / self.group_size), out_features), + dtype=self.float_type, + ).to(device), + ) + self.register_buffer( + "qweight", + torch.zeros( + (math.ceil(in_features / self.n_pack), out_features), + dtype=self.compression_dtype, + ).to(device), + ) + self.register_buffer( + "qzeros", + torch.zeros( + (math.ceil(self.in_features / self.group_size), math.ceil(self.out_features / self.n_pack)), + dtype=self.compression_dtype, + ).to(device), + ) + self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device)) + else: + self.compression_dtype = compression_dtype + self.float_type = scale_dtype + self.register_buffer( + "scales", + torch.zeros( + (out_features, math.ceil(in_features / self.group_size)), + dtype=self.float_type, + ).to(device), + ) + if compression_dim == 1: + self.register_buffer( + "qweight", + torch.zeros( + (out_features, math.ceil(in_features / self.n_pack)), + dtype=self.compression_dtype, + ).to(device), + ) + if zp: + self.register_buffer( + "qzeros", + torch.zeros( + (self.out_features, math.ceil(self.in_features / self.group_size / self.n_pack)), + dtype=self.compression_dtype, + ).to(device), + ) + else: + self.register_buffer( + "qweight", + torch.zeros( + (math.ceil(out_features / self.n_pack), in_features), + dtype=self.compression_dtype, + ).to(device), + ) + if zp: + self.register_buffer( + "qzeros", + torch.zeros( + (math.ceil(self.out_features / self.n_pack), math.ceil(self.in_features / self.group_size)), + dtype=self.compression_dtype, + ).to(device), + ) + if bias: + self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device)) + else: + self.bias = None + if g_idx: + self.register_buffer("g_idx", torch.zeros(in_features, dtype=torch.int32).to(device)) + else: + self.g_idx = None + + def pack(self, int_weight, scale, zp, bias, g_idx=None): + if self.use_optimum_format: + self.scales = self.scales.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() + int_weight = int_weight.to(self.device) + if self.use_optimum_format and zp is None: + # to avoid overflow + int_weight = int_weight.type(torch.int32) + shift_bias = 2 ** (self.bits - 1) + int_weight += shift_bias + zp = torch.zeros_like(scale, dtype=torch.uint8) + shift_bias + if bias is not None: + assert hasattr(self, "bias"), "bias is not set when initializing." + self.bias = bias.type(self.float_type).to(self.device) + if g_idx is not None: + assert hasattr(self, "g_idx"), "g_idx is not set when initializing." + self.g_idx = g_idx.type(torch.int32).to(self.device) + if self.use_optimum_format: + invperm = torch.argsort(self.g_idx) + self.g_idx = invperm // self.group_size + self.g_idx = self.g_idx.type(torch.int32).to(self.device) + assert scale.shape == self.scales.shape, "Scale shape is mismatched." + self.scales = scale.type(self.float_type).to(self.device) + if not self.use_optimum_format and self.compression_dim == 0: + int_weight = int_weight.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() + origin_shape = int_weight.shape + target_shape = self.qweight.shape + assert origin_shape[0] == target_shape[0], "output channels mismatch, please check." 
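# each packed element stores n_pack = compress_bits // bits codes: the mask keeps
# the low `bits` bits of every code, and the e-th code of pack j is shifted into
# bit slot bits*e of qweight[:, j]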
+ mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device) + + # pack weight + for j in range(target_shape[1]): + start = self.n_pack * j + end = self.n_pack * (j + 1) + tmp = int_weight[:, start:end].type(self.compression_dtype) + for e in range(tmp.shape[1]): + tmp[:, e] &= mask + tmp[:, e] = tmp[:, e] << (self.bits * e) + self.qweight[:, j] |= tmp[:, e] + if not self.use_optimum_format and self.compression_dim == 0: + self.qweight = self.qweight.t_().contiguous() + + if zp is not None: + zp = zp.to(self.device) + if self.use_optimum_format: + zp -= 1 + if self.use_optimum_format or self.compression_dim == 0: + zp = zp.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() + assert hasattr(self, "qzeros"), "zp is not set when initializing." + target_shape = self.qzeros.shape + for j in range(target_shape[1]): + start = self.n_pack * j + end = self.n_pack * (j + 1) + tmp = zp[:, start:end].type(self.compression_dtype) + for e in range(tmp.shape[1]): + tmp[:, e] &= mask + tmp[:, e] = tmp[:, e] << (self.bits * e) + self.qzeros[:, j] |= tmp[:, e] + if self.use_optimum_format or self.compression_dim == 0: + self.qzeros = self.qzeros.t_().contiguous() + if self.use_optimum_format: + self.scales = self.scales.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() + + def recover(self): + logger.debug(f"Recovering {self} weight") + scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales + qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight + + device = scales.device + fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device) + if self.g_idx is None: + # used for recovering fp32_weight + self.g_idx = torch.tensor([i // self.group_size for i in range(self.in_features)], dtype=torch.int32) + mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(device) + if hasattr(self, "qzeros"): + weight_dtype = torch.uint8 + else: + weight_dtype = torch.int8 + # unpack weight + weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device) + if not self.use_optimum_format and self.compression_dim == 0: + weight = weight.t_().contiguous() + qweight = qweight.t_().contiguous() + origin_shape = weight.shape + target_shape = qweight.shape + for j in range(target_shape[1]): + for e in range(self.n_pack): + index = j * self.n_pack + e + if index >= origin_shape[1]: + continue + tmp = qweight[:, j] + tmp = tmp << (self.compress_bits - self.bits * (e + 1)) + tmp = tmp >> self.compress_bits - self.bits + if weight_dtype == torch.uint8: + tmp &= mask # remove sign bit + weight[:, index] = tmp.type(weight_dtype) + if not self.use_optimum_format and self.compression_dim == 0: + weight = weight.t_().contiguous() + if "int" not in self.dtype: + new_weight = torch.zeros(self.out_features, self.in_features).to(device) + for k, v in self.int2float_mapping.items(): + new_weight += torch.where(weight == k, v, 0) + weight = new_weight + # unpack zero_point + if hasattr(self, "qzeros"): + zp_dtype = self.compression_dtype # to avoid overflow when weight-zp + zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device) + qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros + if self.use_optimum_format or self.compression_dim == 0: + zp = zp.t_().contiguous() + qzeros = qzeros.t_().contiguous() + origin_shape = zp.shape + target_shape = qzeros.shape + for j in 
range(target_shape[1]): + for e in range(self.n_pack): + index = j * self.n_pack + e + if index >= origin_shape[1]: + continue + tmp = qzeros[:, j] + tmp = tmp << (self.compress_bits - self.bits * (e + 1)) + tmp = tmp >> self.compress_bits - self.bits + tmp &= mask + zp[:, index] = tmp.type(zp_dtype) + if self.use_optimum_format or self.compression_dim == 0: + zp = zp.t_().contiguous() + if self.use_optimum_format: + # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1 + zp += 1 + zp = torch.where(zp > (2**self.bits - 1), 0, zp) + # recover fp32 weight with int_weight, scale, and zero_point + for idx in range(self.in_features): + fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * scales[:, self.g_idx[idx]] + else: + # recover fp32 weight with int_weight, scale + for idx in range(self.in_features): + fp32_weight[:, idx] = weight[:, idx] * scales[:, self.g_idx[idx]] + return fp32_weight + + def forward(self, input): + if not hasattr(self, "weight"): + weight = self.recover() + device = self.scales.device + if weight.dtype == torch.float16 and device.type == "cpu": + weight = weight.float() + self.bias = self.bias.float() if self.bias is not None else None + if True: # keep reusing self.weight due to recover is too slow. + if not hasattr(self, "weight"): + self.weight = weight + input = input.type(self.weight.dtype) + logger.debug(f"Calculating {self}") + return F.linear(input, self.weight, self.bias) + else: + input = input.type(weight.dtype) + return F.linear(input, weight, self.bias) + + def extra_repr(self) -> str: + tmp_str = "in_features={}, out_features={}, bits={}, group_size={}, bias={}".format( + self.in_features, + self.out_features, + self.bits, + self.group_size, + self.bias is not None, + ) + if self.use_optimum_format: + tmp_str += ", use_optimum_format=True" + return tmp_str + + +class FakeAffineTensorQuantFunction(Function): + """Fake version of affine quantization.""" + + @staticmethod + def forward(ctx, inputs, num_bits=4, group_size=1024, scheme="asym"): + """As it will be only applied on activation with per tensor granularity, broadcast is not needed. + + Args: + ctx: Pytorch convention. + inputs: A Tensor of type float32. + min_range: A float. + max_range: A float. + num_bits: An integer + + Returns: + outputs: A Tensor of type output_dtype + """ + return quant_tensor(inputs, num_bits, group_size, scheme) + + @staticmethod + def backward(ctx, grad_outputs): + """ + Args: + ctx: Pytorch convention. 
+ grad_output: A tensor of gradient of outputs + + Returns: + grad_inputs: A tensor of gradient + """ + return grad_outputs, None, None, None + + +class TEQLinearFakeQuant(torch.nn.Module): + """Wrapper quantization linear.""" + + def __init__(self, orig_layer, alpha=None, num_bits=4, group_size=-1, scheme="asym"): + """A forward hook to linear module + :param orig_layer: the original module + :param alpha: trainable alpha/scale + :param num_bits: quantization level + :param group_size: for fine-grained quantization.""" + super(TEQLinearFakeQuant, self).__init__() + self.orig_layer = orig_layer + self.alpha = alpha + + self.num_bits = num_bits + self.group_size = group_size + self.scheme = scheme + + def forward(self, x): + alpha = torch.clip(self.alpha, 1e-5) + shape_len = len(x.shape) - 1 + shape = (1,) * shape_len + (-1,) + x = x / alpha.view(shape) + weight = self.orig_layer.weight + weight = weight * alpha.unsqueeze(dim=0) + weight_q = FakeAffineTensorQuantFunction().apply(weight, self.num_bits, self.group_size, self.scheme) + return F.linear(x, weight_q, self.orig_layer.bias) + + +class MulLinear(torch.nn.Module): + """Linear wrapper to apply scale to input.""" + + def __init__(self, module, input_scale=None): + """A forward hook to save input max of a module + :param module: the linear module + :param input_scale: scale for input.""" + super().__init__() + if input_scale is None: + input_scale = torch.empty(module.in_features) + self.register_buffer("input_scale", input_scale) + self.add_module("linear", module) + + @property + def weight(self): + return self.linear.weight + + @weight.setter + def weight(self, weight): + self.linear.weight = weight + + def forward(self, X): + X = torch.mul(X, self.input_scale) + X = self.linear(X) + return X + + def _update_linear(self): + # update linear weight with input_scale + scale = self.input_scale.view(1, self.input_scale.shape[0]) + with torch.no_grad(): + self.linear.weight /= scale + + def _recover_linear(self): + # remove mul and reset sq_linear for ipex inference + scale = self.input_scale.view(1, self.input_scale.shape[0]) + with torch.no_grad(): + self.linear.weight *= scale diff --git a/neural_compressor/torch/algorithms/weight_only/rtn.py b/neural_compressor/torch/algorithms/weight_only/rtn.py index e4fba834c92..a47f5c74d41 100644 --- a/neural_compressor/torch/algorithms/weight_only/rtn.py +++ b/neural_compressor/torch/algorithms/weight_only/rtn.py @@ -111,6 +111,9 @@ def rtn_quantize( "double_quant_scheme": weight_config[name]["double_quant_scheme"], "double_quant_group_size": weight_config[name]["double_quant_group_size"], } + if dtype != "int" and "int" in dtype: + bits = int(dtype.lstrip("int")) + dtype = "int" log_msg = ( f"RTN quantization config: bits={bits}, group_size={group_size}, " + f"scheme={scheme}, quantile={quantile}" ) @@ -143,14 +146,14 @@ def rtn_quantize( int_weight = int_weight.t_().contiguous() if group_dim == 0 else int_weight scale = scale.t_().contiguous() if group_dim == 0 else scale zp = zp.t_().contiguous() if group_dim == 0 and zp is not None else zp - from neural_compressor.torch.quantization.modules import WeightOnlyLinear + from .modules import WeightOnlyLinear new_module = WeightOnlyLinear( m.in_features, m.out_features, + dtype=dtype, bits=bits, group_size=group_size, - dtype=dtype, zp=zp is not None, bias=m.bias is not None, use_optimum_format=use_optimum_format, diff --git a/neural_compressor/torch/algorithms/weight_only/utility.py b/neural_compressor/torch/algorithms/weight_only/utility.py index 
877aa22e1ca..2f482aa9189 100644 --- a/neural_compressor/torch/algorithms/weight_only/utility.py +++ b/neural_compressor/torch/algorithms/weight_only/utility.py @@ -73,7 +73,7 @@ def quantize_4bit(tensor, quantile=1.0, dtype="nf4", return_int=False, **kwargs) mid_data = [(allow_data[i] + allow_data[i + 1]) / 2 for i in range(len(allow_data) - 1)] q_tensor = torch.zeros_like(tensor) for i in range(len(allow_data)): - data = allow_data_bit[i] if return_int else allow_data[i] + data = allow_data_bit[i] if return_int or "cast_int" in kwargs else allow_data[i] if i == 0: q_tensor += torch.where(tensor <= mid_data[i], data, 0) elif i == len(allow_data) - 1: @@ -295,9 +295,9 @@ def quant_tensor( return weight if quant_scale: weight, scale, zp = q_state - scale_dtype = kwargs.get("double_quant_dtype", "fp32") + scale_dtype = kwargs.get("double_quant_dtype", "int") scale_bits = kwargs.get("double_quant_bits", 8) - scale_scheme = kwargs.get("double_quant_scheme", "sym") + scale_scheme = kwargs.get("double_quant_scheme", "asym") scale_group_size = kwargs.get("double_quant_group_size", 256) scale_return_int = kwargs.get("double_quant_return_int", return_int) orig_scale_shape = scale.shape @@ -308,7 +308,7 @@ def quant_tensor( scale.sub_(scale_mean) scale_scheme = "sym" # process: scale - quant_tensor( + scale = quant_tensor( scale, dtype=scale_dtype, bits=scale_bits, @@ -397,8 +397,8 @@ def search_clip(m, bits=4, group_size=32, scheme="asym", dtype="int", enable_ful return best_clip_ratio -def quant_weight_w_scale(weight, scale, zp, group_size=-1, dtype="int"): - """Quant and dequant tensor with group size. +def quant_weight_w_scale(weight, scale, zp=None, group_size=-1, dtype="int"): + """Quant and dequant tensor with group size. It's an in-place function. Args: weight: input weight @@ -412,32 +412,34 @@ def quant_weight_w_scale(weight, scale, zp, group_size=-1, dtype="int"): """ device = weight.device scale = scale.to(device) - # NF4 FP4 - if dtype in FLOAT_MAPPING.keys(): - int_weight = quantize_4bit( - weight, - quantile=1.0, - dtype=dtype, - return_int=True, - scale=scale, - )[0] - return int_weight - # INT if zp is not None: zp = zp.to(device) + # group_size = -1 if group_size == -1: + if dtype in FLOAT_MAPPING.keys(): # NF4 FP4 + return quantize_4bit(weight, scale=scale, dtype=dtype, return_int=True)[0] return weight.div_(scale).round_() if zp is None else weight.div_(scale).add_(zp).round_() int_weight = torch.zeros(weight.shape).to(device) leng = weight.shape[1] // group_size tail_flag = False if weight.shape[1] % group_size == 0 else True + # group_size != -1 for i in range(leng): - int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size].div_(scale[:, i].unsqueeze(1)) - if zp is not None: - int_weight_tmp.add_(zp[:, i].unsqueeze(1)) - int_weight[:, i * group_size : (i + 1) * group_size].copy_(int_weight_tmp.round_()) + if dtype in FLOAT_MAPPING.keys(): # NF4 FP4 + int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size] + quantize_4bit(int_weight_tmp, scale=scale[:, i].unsqueeze(1), dtype=dtype, return_int=True)[0] + else: + int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size].div_(scale[:, i].unsqueeze(1)) + if zp is not None: + int_weight_tmp.add_(zp[:, i].unsqueeze(1)) + int_weight[:, i * group_size : (i + 1) * group_size].copy_(int_weight_tmp.round_()) + # tail_flag if tail_flag: - int_weight_tmp = weight[:, leng * group_size :].div_(scale[:, -1].unsqueeze(1)) - if zp is not None: - int_weight_tmp.add_(zp[:, -1].unsqueeze(1)) - int_weight[:, leng * 
group_size :].copy_(int_weight_tmp.round_()) + if dtype in FLOAT_MAPPING.keys(): # NF4 FP4 + int_weight_tmp = weight[:, leng * group_size :] + quantize_4bit(int_weight_tmp, scale=scale[:, -1].unsqueeze(1), dtype=dtype, return_int=True)[0] + else: + int_weight_tmp = weight[:, leng * group_size :].div_(scale[:, -1].unsqueeze(1)) + if zp is not None: + int_weight_tmp.add_(zp[:, -1].unsqueeze(1)) + int_weight[:, leng * group_size :].copy_(int_weight_tmp.round_()) return int_weight diff --git a/neural_compressor/torch/quantization/__init__.py b/neural_compressor/torch/quantization/__init__.py index 4de5a0232e1..06b0a6a058e 100644 --- a/neural_compressor/torch/quantization/__init__.py +++ b/neural_compressor/torch/quantization/__init__.py @@ -25,7 +25,6 @@ get_default_sq_config, ) -# TODO(Yi): move config to config.py from neural_compressor.torch.quantization.autotune import ( autotune, TuningConfig, diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index 8734f3a2651..fd5c04f0f3c 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -60,8 +60,35 @@ def gptq_entry( model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], GPTQConfig], *args, **kwargs ) -> torch.nn.Module: logger.info("Quantize model with the GPTQ algorithm.") + # rebuild weight_config for gptq_quantize function + weight_config = {} + for (op_name, op_type), quant_config in configs_mapping.items(): + weight_config[op_name] = { + "dtype": quant_config.dtype, + "bits": quant_config.bits, + "sym": quant_config.use_sym, + "group_size": quant_config.group_size, + "mse": quant_config.use_mse_search, + "use_double_quant": quant_config.use_double_quant, + "double_quant_dtype": quant_config.double_quant_dtype, + "double_quant_bits": quant_config.double_quant_bits, + "double_quant_sym": quant_config.double_quant_use_sym, + "double_quant_group_size": quant_config.double_quant_group_size, + "act_order": quant_config.act_order, + "percdamp": quant_config.percdamp, + "block_size": quant_config.block_size, + "static_groups": quant_config.static_groups, + } + kwargs.update( + { + "export_compressed_model": quant_config.export_compressed_model, + "use_layer_wise": quant_config.use_layer_wise, + "model_path": quant_config.model_path, + } + ) - model, quantization_perm = gptq_quantize(model=model, configs_mapping=configs_mapping, *args, **kwargs) + logger.warning("lm_head in transformer model is skipped by GPTQ") + model, quantization_perm = gptq_quantize(model=model, weight_config=weight_config, *args, **kwargs) # Assign the gptq config as an attribute of model model._gptq_quantization_perm = quantization_perm return model diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index fa56f9ceb0b..dd03e2f3431 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -65,12 +65,16 @@ class RTNConfig(BaseConfig): params_list = [ "dtype", "bits", - "group_size", "use_sym", + "group_size", + "group_dim", "use_full_range", "use_mse_search", - "use_layer_wise", "export_compressed_model", + # layer wise params + "use_layer_wise", + "model_path", + # double quant "use_double_quant", "double_quant_dtype", "double_quant_bits", @@ -88,13 +92,15 @@ def __init__( group_dim: int = 1, use_full_range: bool = False, use_mse_search: bool = False, - use_layer_wise: bool = False, 
export_compressed_model: bool = False, + # layer wise + use_layer_wise: bool = False, + model_path: str = "", # double quant use_double_quant: bool = False, double_quant_dtype: str = "int", double_quant_bits: int = 8, # not available when double_quant_dtype is not 'int' - double_quant_use_sym: bool = True, + double_quant_use_sym: bool = False, double_quant_group_size: int = 256, # Tuning space white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, @@ -102,20 +108,21 @@ def __init__( """Init RTN weight-only quantization config. Args: - dtype (str): Data type for weights, default is "int". - bits (int): Number of bits used to represent weights, default is 4. - use_sym (bool): Indicates whether weights are symmetric, default is True. - group_size (int): Size of weight groups, default is 32. - group_dim (int): Dimension for grouping, default is 1. - use_full_range (bool): Enables full range for activations, default is False. - use_mse_search (bool): Enables mean squared error (MSE) search, default is False. - use_layer_wise (bool): Enables quantize model per layer. Defaults to False. + dtype (str): Data type for weights. Default is "int". + bits (int): Number of bits used to represent weights. Default is 4. + use_sym (bool): Indicates whether weights are symmetric. Default is True. + group_size (int): Size of weight groups. Default is 32. + group_dim (int): Dimension for grouping. Default is 1. + use_full_range (bool): Enables full range for activations. Default is False. + use_mse_search (bool): Enables mean squared error (MSE) search. Default is False. export_compressed_model (bool): Enables return model in int format or not. Defaults to False. - use_double_quant (bool): Enables double quantization, default is False. - double_quant_dtype (str): Data type for double_quant scale, default is "int". - double_quant_bits (int): Number of bits used to represent double_quant scale, default is 4. - double_quant_use_sym (bool): Indicates whether double_quant scale are symmetric, default is True. - double_quant_group_size (int): Size of double_quant groups, default is 32. + use_layer_wise (bool): Enables quantize model per layer. Defaults to False. + model_path (str): Model path that is used to load state_dict per layer. + use_double_quant (bool): Enables double quantization. Default is False. + double_quant_dtype (str): Data type for double_quant scale. Default is "int". + double_quant_bits (int): Number of bits used to represent double_quant scale. Default is 4. + double_quant_use_sym (bool): Indicates whether double_quant scale are symmetric. Default is True. + double_quant_group_size (int): Size of double_quant groups. Default is 32. 
""" super().__init__(white_list=white_list) self.dtype = dtype @@ -125,15 +132,16 @@ def __init__( self.group_dim = group_dim self.use_full_range = use_full_range self.use_mse_search = use_mse_search - self.use_layer_wise = use_layer_wise self.export_compressed_model = export_compressed_model + self.use_layer_wise = use_layer_wise + self.model_path = model_path # double quant self.use_double_quant = use_double_quant self.double_quant_bits = double_quant_bits self.double_quant_dtype = double_quant_dtype self.double_quant_use_sym = double_quant_use_sym self.double_quant_group_size = double_quant_group_size - self._post_init() # initialize local configuration + self._post_init() # initialize global & local configuration @classmethod def register_supported_configs(cls) -> List[OperatorConfig]: @@ -204,89 +212,107 @@ class GPTQConfig(BaseConfig): name = GPTQ supported_configs: List[OperatorConfig] = [] params_list = [ - "weight_dtype", - "weight_bits", - "weight_group_size", - "weight_sym", - "block_size", - "act_dtype", - "group_dim", - "nsamples", - "dataloader_len", - "percdamp", - "act_order", - "use_max_length", - "pad_max_length", - "enable_mse_search", - "device", - "layer_wise", - "return_int", + "dtype", + "bits", + "use_sym", + "group_size", + "use_mse_search", + "export_compressed_model", + "use_double_quant", "double_quant_dtype", "double_quant_bits", - "double_quant_sym", + "double_quant_use_sym", "double_quant_group_size", + # layer wise params + "use_layer_wise", + "model_path", + # gptq params + "act_order", + "percdamp", + "block_size", + "static_groups", ] def __init__( self, - weight_dtype: str = "int", - weight_bits: int = 4, - weight_group_size: int = 32, - weight_sym: bool = True, - block_size: int = 128, - act_dtype: str = "fp32", - group_dim: int = 1, - nsamples: int = 128, - dataloader_len: int = 10, - percdamp: float = 0.01, - act_order: bool = False, - use_max_length: bool = True, - pad_max_length: int = 2048, - enable_mse_search: bool = False, - device=None, - layer_wise: bool = False, - return_int: bool = False, - double_quant_dtype: str = "fp32", - double_quant_bits: int = 8, - double_quant_sym: bool = True, + dtype: str = "int", + bits: int = 4, + use_sym: bool = True, + group_size: int = 32, + use_mse_search: bool = False, + export_compressed_model: bool = False, + # layer wise + use_layer_wise: bool = False, + model_path: str = "", + # double quant + use_double_quant: bool = False, + double_quant_dtype: str = "int", + double_quant_bits: int = 8, # not available when double_quant_dtype is not 'int' + double_quant_use_sym: bool = False, double_quant_group_size: int = 256, + # gptq params + act_order: bool = False, + percdamp: float = 0.01, + block_size: int = 2048, + static_groups: bool = False, + # Tuning space white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, ): - """Init GPTQ config. + """Init RTN weight-only quantization config. Args: + dtype (str): Data type for weights. Default is "int". + bits (int): Number of bits used to represent weights. Default is 4. + use_sym (bool): Indicates whether weights are symmetric. Default is True. + group_size (int): Size of weight groups. Default is 32. + use_mse_search (bool): Enables mean squared error (MSE) search. Default is False. + export_compressed_model (bool): Enables return model in int format or not. Defaults to False. + use_layer_wise (bool): Enables quantize model per layer. Defaults to False. + model_path (str): Model path that is used to load state_dict per layer. 
+ use_double_quant (bool): Enables double quantization. Default is False. + double_quant_dtype (str): Data type for double_quant scale. Default is "int". + double_quant_bits (int): Number of bits used to represent double_quant scale. Default is 4. + double_quant_use_sym (bool): Indicates whether double_quant scale are symmetric. Default is True. + double_quant_group_size (int): Size of double_quant groups. Default is 32. + act_order (bool): Whether to sort Hessian's diagonal values to rearrange channel-wise + quantization order. Default is False. + percdamp (float): Percentage of Hessian's diagonal values' average, which will be added to + Hessian's diagonal to increase numerical stability. Default is 0.01. + block_size (int): Execute GPTQ quantization per block, block shape = [C_out, block_size]. + Default is 128. + static_groups (bool): Whether to calculate group wise quantization parameters in advance. + This option mitigate actorder's extra computational requirements. + Default is False. """ super().__init__(white_list=white_list) - self.weight_dtype = weight_dtype - self.weight_bits = weight_bits - self.weight_group_size = weight_group_size - self.weight_sym = weight_sym - self.act_dtype = act_dtype - self.block_size = block_size - self.enable_mse_search = enable_mse_search - self.group_dim = group_dim - self.nsamples = nsamples - # TODO(Yi) detect it auto - self.dataloader_len = dataloader_len - self.percdamp = percdamp - self.act_order = act_order - self.use_max_length = use_max_length - self.pad_max_length = pad_max_length - self.layer_wise = layer_wise - self.device = device - self.return_int = return_int + self.dtype = dtype + self.bits = bits + self.use_sym = use_sym + self.group_size = group_size + self.use_mse_search = use_mse_search + self.export_compressed_model = export_compressed_model + # layer wise + self.use_layer_wise = use_layer_wise + self.model_path = model_path + # double quant + self.use_double_quant = use_double_quant self.double_quant_bits = double_quant_bits self.double_quant_dtype = double_quant_dtype - self.double_quant_sym = double_quant_sym + self.double_quant_use_sym = double_quant_use_sym self.double_quant_group_size = double_quant_group_size - self._post_init() + # gptq + self.act_order = act_order + self.percdamp = percdamp + self.block_size = block_size + self.static_groups = static_groups + self._post_init() # initialize global & local configuration @classmethod def register_supported_configs(cls) -> List[OperatorConfig]: supported_configs = [] # TODO(Yi) linear_gptq_config = GPTQConfig() - operators = [torch.nn.Linear, torch.nn.functional.linear] + operators = [torch.nn.Linear] supported_configs.append(OperatorConfig(config=linear_gptq_config, operators=operators)) cls.supported_configs = supported_configs @@ -304,7 +330,7 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: @classmethod def get_config_set_for_tuning(cls) -> Union[None, "GPTQConfig", List["GPTQConfig"]]: # TODO fwk owner needs to update it. 
- return GPTQConfig(weight_bits=[4, 6]) + return GPTQConfig(act_order=[True, False], use_sym=[False, True]) def get_default_gptq_config() -> GPTQConfig: diff --git a/neural_compressor/torch/quantization/modules.py b/neural_compressor/torch/quantization/modules.py index b5b2e4b6e45..d01f1bd781e 100644 --- a/neural_compressor/torch/quantization/modules.py +++ b/neural_compressor/torch/quantization/modules.py @@ -21,24 +21,6 @@ import torch import torch.nn as nn -from packaging.version import Version -from torch.autograd import Function -from torch.nn import functional as F - -from neural_compressor.common import logger -from neural_compressor.torch.algorithms.weight_only import quant_tensor - - -def get_torch_version(): - try: - torch_version = torch.__version__.split("+")[0] - except ValueError as e: # pragma: no cover - assert False, "Got an unknown version of torch: {}".format(e) - version = Version(torch_version) - return version - - -PT_VERSION = get_torch_version().release class Matmul(nn.Module): @@ -63,477 +45,3 @@ def __init__(self): def forward(self, x): return x - - -class QDQLayer(torch.nn.Module): - def __init__(self, module, input_scale=None) -> None: - super().__init__() - self.quant = torch.ao.quantization.QuantStub() - self.module = module - self.dequant = torch.ao.quantization.DeQuantStub() - self.input_scale = input_scale - - def forward(self, X): - if self.input_scale is not None: - X = torch.mul(X, self.input_scale) - X = self.quant(X) - X = self.module(X) - X = self.dequant(X) - return X - - -class SQLinearWrapper(torch.nn.Module): - def __init__(self, module, input_scale, input_minmax, alpha=0.5, dtype=torch.quint8): - super().__init__() - self.register_buffer("input_scale", input_scale) - self.alpha = alpha - self.dtype = dtype - # calculate and only save scale, zero_point to avoid memory usage - self.scale, self.zero_point = self._calculate_qparams(input_scale, input_minmax, dtype) - self.add_module("sq_linear", module) - self._update_sq_linear() - self.ipex = False # a flag used for ipex inference - - @property - def weight(self): - return self.sq_linear.weight - - def forward(self, X): - if self.ipex: - X = self.sq_linear(X) - else: - X = torch.mul(X, self.input_scale) - X = self.sq_linear(X) - return X - - def _calculate_qparams(self, input_scale, input_minmax, dtype=torch.quint8): - # calculate scale and zero_point - if dtype == torch.quint8: - quant_min, quant_max = 0, 255 - min_val = torch.min(input_minmax[0] * input_scale) - max_val = torch.max(input_minmax[1] * input_scale) - # work when min_val bigger than zero. 
- min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) - max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) - scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) - scale = torch.max(scale, torch.tensor([torch.finfo(torch.float32).eps])) - zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) - zero_point = torch.clamp(zero_point, quant_min, quant_max) - return scale, zero_point - - def _get_weight_scale(self): - # get weight scale and zero_point - from torch.ao.quantization.observer import default_per_channel_weight_observer - - obs = default_per_channel_weight_observer() - obs(self.sq_linear.weight) - scale, _ = obs.calculate_qparams() - return scale - - def _update_sq_linear(self): - # remove mul and reset sq_linear for ipex inference - scale = self.input_scale.view(1, self.input_scale.shape[0]) - with torch.no_grad(): - self.sq_linear.weight /= scale - - def _recover_sq_linear(self): - # remove mul and reset sq_linear for ipex inference - scale = self.input_scale.view(1, self.input_scale.shape[0]) - with torch.no_grad(): - self.sq_linear.weight *= scale - - -class WeightOnlyLinear(torch.nn.Module): - def __init__( - self, - in_features, - out_features, - bits, - group_size, - dtype="int", - zp=False, - bias=False, - scale_dtype=torch.float32, - compression_dtype=torch.int32, - compression_dim=1, - g_idx=False, - device="cpu", - use_optimum_format=True, - ): - super().__init__() - self.use_optimum_format = use_optimum_format - self.dtype = dtype - if "int" not in self.dtype: # for nf4, fp4 - from neural_compressor.torch.algorithms.weight_only import FLOAT_MAPPING, INT_MAPPING - - self.use_optimum_format = False # optimum_format doesn't suit for symmetric nf4 fp4. - float_list = FLOAT_MAPPING[self.dtype] - int_list = INT_MAPPING[self.dtype] - self.int2float_mapping = {} - for k, v in zip(int_list, float_list): - self.int2float_mapping[k] = v - self.device = device - self.in_features = in_features - self.out_features = out_features - self.bits = bits - self.group_size = group_size if group_size != -1 else in_features - self.compression_dim = compression_dim - assert compression_dtype in [ - torch.int8, - torch.int16, - torch.int32, - torch.int64, - ], "Only support torch.int8|16|32|64 as compressed dtype." - dtype_bits_mapping = {torch.int8: 8, torch.int16: 16, torch.int32: 32, torch.int64: 64} - self.compress_bits = dtype_bits_mapping[compression_dtype] - self.n_pack = self.compress_bits // self.bits - # K is input channel, N is output channel - assert compression_dim in [0, 1], ( - "Only support 0 or 1 as compression dimension, " + "0 is output channel, 1 is input channel." 
- ) - if self.use_optimum_format: - self.float_type = torch.float16 - self.compression_dtype = torch.int32 - self.register_buffer( - "scales", - torch.zeros( - (math.ceil(in_features / self.group_size), out_features), - dtype=self.float_type, - ).to(device), - ) - self.register_buffer( - "qweight", - torch.zeros( - (math.ceil(in_features / self.n_pack), out_features), - dtype=self.compression_dtype, - ).to(device), - ) - self.register_buffer( - "qzeros", - torch.zeros( - (math.ceil(self.in_features / self.group_size), math.ceil(self.out_features / self.n_pack)), - dtype=self.compression_dtype, - ).to(device), - ) - self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device)) - else: - self.compression_dtype = compression_dtype - self.float_type = scale_dtype - self.register_buffer( - "scales", - torch.zeros( - (out_features, math.ceil(in_features / self.group_size)), - dtype=self.float_type, - ).to(device), - ) - if compression_dim == 1: - self.register_buffer( - "qweight", - torch.zeros( - (out_features, math.ceil(in_features / self.n_pack)), - dtype=self.compression_dtype, - ).to(device), - ) - if zp: - self.register_buffer( - "qzeros", - torch.zeros( - (self.out_features, math.ceil(self.in_features / self.group_size / self.n_pack)), - dtype=self.compression_dtype, - ).to(device), - ) - else: - self.register_buffer( - "qweight", - torch.zeros( - (math.ceil(out_features / self.n_pack), in_features), - dtype=self.compression_dtype, - ).to(device), - ) - if zp: - self.register_buffer( - "qzeros", - torch.zeros( - (math.ceil(self.out_features / self.n_pack), math.ceil(self.in_features / self.group_size)), - dtype=self.compression_dtype, - ).to(device), - ) - if bias: - self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device)) - else: - self.bias = None - if g_idx: - self.register_buffer("g_idx", torch.zeros(in_features, dtype=torch.int32).to(device)) - else: - self.g_idx = None - - def pack(self, int_weight, scale, zp, bias, g_idx=None): - if self.use_optimum_format: - self.scales = self.scales.t_().contiguous() - self.qweight = self.qweight.t_().contiguous() - self.qzeros = self.qzeros.t_().contiguous() - int_weight = int_weight.to(self.device) - if self.use_optimum_format and zp is None: - # to avoid overflow - int_weight = int_weight.type(torch.int32) - shift_bias = 2 ** (self.bits - 1) - int_weight += shift_bias - zp = torch.zeros_like(scale, dtype=torch.uint8) + shift_bias - if bias is not None: - assert hasattr(self, "bias"), "bias is not set when initializing." - self.bias = bias.type(self.float_type).to(self.device) - if g_idx is not None: - assert hasattr(self, "g_idx"), "g_idx is not set when initializing." - self.g_idx = g_idx.type(torch.int32).to(self.device) - if self.use_optimum_format: - invperm = torch.argsort(self.g_idx) - self.g_idx = invperm // self.group_size - self.g_idx = self.g_idx.type(torch.int32).to(self.device) - assert scale.shape == self.scales.shape, "Scale shape is mismatched." - self.scales = scale.type(self.float_type).to(self.device) - if not self.use_optimum_format and self.compression_dim == 0: - int_weight = int_weight.t_().contiguous() - self.qweight = self.qweight.t_().contiguous() - origin_shape = int_weight.shape - target_shape = self.qweight.shape - assert origin_shape[0] == target_shape[0], "output channels mismatch, please check." 
- mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device) - - # pack weight - for j in range(target_shape[1]): - start = self.n_pack * j - end = self.n_pack * (j + 1) - tmp = int_weight[:, start:end].type(self.compression_dtype) - for e in range(tmp.shape[1]): - tmp[:, e] &= mask - tmp[:, e] = tmp[:, e] << (self.bits * e) - self.qweight[:, j] |= tmp[:, e] - if not self.use_optimum_format and self.compression_dim == 0: - self.qweight = self.qweight.t_().contiguous() - - if zp is not None: - zp = zp.to(self.device) - if self.use_optimum_format: - zp -= 1 - if self.use_optimum_format or self.compression_dim == 0: - zp = zp.t_().contiguous() - self.qzeros = self.qzeros.t_().contiguous() - assert hasattr(self, "qzeros"), "zp is not set when initializing." - target_shape = self.qzeros.shape - for j in range(target_shape[1]): - start = self.n_pack * j - end = self.n_pack * (j + 1) - tmp = zp[:, start:end].type(self.compression_dtype) - for e in range(tmp.shape[1]): - tmp[:, e] &= mask - tmp[:, e] = tmp[:, e] << (self.bits * e) - self.qzeros[:, j] |= tmp[:, e] - if self.use_optimum_format or self.compression_dim == 0: - self.qzeros = self.qzeros.t_().contiguous() - if self.use_optimum_format: - self.scales = self.scales.t_().contiguous() - self.qweight = self.qweight.t_().contiguous() - self.qzeros = self.qzeros.t_().contiguous() - - def recover(self): - logger.debug(f"Recovering {self} weight") - scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales - qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight - - device = scales.device - fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device) - if self.g_idx is None: - # used for recovering fp32_weight - self.g_idx = torch.tensor([i // self.group_size for i in range(self.in_features)], dtype=torch.int32) - mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(device) - if hasattr(self, "qzeros"): - weight_dtype = torch.uint8 - else: - weight_dtype = torch.int8 - # unpack weight - weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device) - if not self.use_optimum_format and self.compression_dim == 0: - weight = weight.t_().contiguous() - qweight = qweight.t_().contiguous() - origin_shape = weight.shape - target_shape = qweight.shape - for j in range(target_shape[1]): - for e in range(self.n_pack): - index = j * self.n_pack + e - if index >= origin_shape[1]: - continue - tmp = qweight[:, j] - tmp = tmp << (self.compress_bits - self.bits * (e + 1)) - tmp = tmp >> self.compress_bits - self.bits - if weight_dtype == torch.uint8: - tmp &= mask # remove sign bit - weight[:, index] = tmp.type(weight_dtype) - if not self.use_optimum_format and self.compression_dim == 0: - weight = weight.t_().contiguous() - if "int" not in self.dtype: - new_weight = torch.zeros(self.out_features, self.in_features).to(device) - for k, v in self.int2float_mapping.items(): - new_weight += torch.where(weight == k, v, 0) - weight = new_weight - # unpack zero_point - if hasattr(self, "qzeros"): - zp_dtype = self.compression_dtype # to avoid overflow when weight-zp - zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device) - qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros - if self.use_optimum_format or self.compression_dim == 0: - zp = zp.t_().contiguous() - qzeros = qzeros.t_().contiguous() - origin_shape = zp.shape - target_shape = qzeros.shape - for j in 
range(target_shape[1]): - for e in range(self.n_pack): - index = j * self.n_pack + e - if index >= origin_shape[1]: - continue - tmp = qzeros[:, j] - tmp = tmp << (self.compress_bits - self.bits * (e + 1)) - tmp = tmp >> self.compress_bits - self.bits - tmp &= mask - zp[:, index] = tmp.type(zp_dtype) - if self.use_optimum_format or self.compression_dim == 0: - zp = zp.t_().contiguous() - if self.use_optimum_format: - # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1 - zp += 1 - zp = torch.where(zp > (2**self.bits - 1), 0, zp) - # recover fp32 weight with int_weight, scale, and zero_point - for idx in range(self.in_features): - fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * scales[:, self.g_idx[idx]] - else: - # recover fp32 weight with int_weight, scale - for idx in range(self.in_features): - fp32_weight[:, idx] = weight[:, idx] * scales[:, self.g_idx[idx]] - return fp32_weight - - def forward(self, input): - if not hasattr(self, "weight"): - weight = self.recover() - device = self.scales.device - if weight.dtype == torch.float16 and device.type == "cpu": - weight = weight.float() - self.bias = self.bias.float() if self.bias is not None else None - if True: # keep reusing self.weight due to recover is too slow. - if not hasattr(self, "weight"): - self.weight = weight - input = input.type(self.weight.dtype) - logger.debug(f"Calculating {self}") - return F.linear(input, self.weight, self.bias) - else: - input = input.type(weight.dtype) - return F.linear(input, weight, self.bias) - - def extra_repr(self) -> str: - tmp_str = "in_features={}, out_features={}, bits={}, group_size={}, bias={}".format( - self.in_features, - self.out_features, - self.bits, - self.group_size, - self.bias is not None, - ) - if self.use_optimum_format: - tmp_str += ", use_optimum_format=True" - return tmp_str - - -class FakeAffineTensorQuantFunction(Function): - """Fake version of affine quantization.""" - - @staticmethod - def forward(ctx, inputs, num_bits=4, group_size=1024, scheme="asym"): - """As it will be only applied on activation with per tensor granularity, broadcast is not needed. - - Args: - ctx: Pytorch convention. - inputs: A Tensor of type float32. - min_range: A float. - max_range: A float. - num_bits: An integer - - Returns: - outputs: A Tensor of type output_dtype - """ - return quant_tensor(inputs, num_bits, group_size, scheme) - - @staticmethod - def backward(ctx, grad_outputs): - """ - Args: - ctx: Pytorch convention. 
- grad_output: A tensor of gradient of outputs - - Returns: - grad_inputs: A tensor of gradient - """ - return grad_outputs, None, None, None - - -class TEQLinearFakeQuant(torch.nn.Module): - """Wrapper quantization linear.""" - - def __init__(self, orig_layer, alpha=None, num_bits=4, group_size=-1, scheme="asym"): - """A forward hook to linear module - :param orig_layer: the original module - :param alpha: trainable alpha/scale - :param num_bits: quantization level - :param group_size: for fine-grained quantization.""" - super(TEQLinearFakeQuant, self).__init__() - self.orig_layer = orig_layer - self.alpha = alpha - - self.num_bits = num_bits - self.group_size = group_size - self.scheme = scheme - - def forward(self, x): - alpha = torch.clip(self.alpha, 1e-5) - shape_len = len(x.shape) - 1 - shape = (1,) * shape_len + (-1,) - x = x / alpha.view(shape) - weight = self.orig_layer.weight - weight = weight * alpha.unsqueeze(dim=0) - weight_q = FakeAffineTensorQuantFunction().apply(weight, self.num_bits, self.group_size, self.scheme) - return F.linear(x, weight_q, self.orig_layer.bias) - - -class MulLinear(torch.nn.Module): - """Linear wrapper to apply scale to input.""" - - def __init__(self, module, input_scale=None): - """A forward hook to save input max of a module - :param module: the linear module - :param input_scale: scale for input.""" - super().__init__() - if input_scale is None: - input_scale = torch.empty(module.in_features) - self.register_buffer("input_scale", input_scale) - self.add_module("linear", module) - - @property - def weight(self): - return self.linear.weight - - @weight.setter - def weight(self, weight): - self.linear.weight = weight - - def forward(self, X): - X = torch.mul(X, self.input_scale) - X = self.linear(X) - return X - - def _update_linear(self): - # update linear weight with input_scale - scale = self.input_scale.view(1, self.input_scale.shape[0]) - with torch.no_grad(): - self.linear.weight /= scale - - def _recover_linear(self): - # remove mul and reset sq_linear for ipex inference - scale = self.input_scale.view(1, self.input_scale.shape[0]) - with torch.no_grad(): - self.linear.weight *= scale diff --git a/neural_compressor/torch/utils/constants.py b/neural_compressor/torch/utils/constants.py index 2c97889f016..9eb297fdc96 100644 --- a/neural_compressor/torch/utils/constants.py +++ b/neural_compressor/torch/utils/constants.py @@ -27,18 +27,19 @@ "double_quant_use_sym": False, "double_quant_group_size": 256, }, - # TODO(Xin): current implementation is not the same as GGML. - # "GGML_TYPE_Q4_K": { - # "dtype": "int", - # "bits": 4, - # "use_sym": False, - # "group_size": 32, - # "use_double_quant": True, - # "double_quant_bits": 6, - # "double_quant_dtype": "int", - # "double_quant_use_sym": True, - # "double_quant_group_size": 8, - # }, + # TODO: (Xin) current implementation is not the same as GGML. + # GGML is using double_quant_bits to quantize zero points + "GGML_TYPE_Q4_K": { + "dtype": "int", + "bits": 4, + "use_sym": False, + "group_size": 32, + "use_double_quant": True, + "double_quant_bits": 6, + "double_quant_dtype": "int", + "double_quant_use_sym": True, + "double_quant_group_size": 8, + }, } # Setting priorities for algorithms, a higher number indicates a higher priority. 
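Note on the re-enabled double-quant presets in constants.py: "GGML_TYPE_Q4_K" pairs 4-bit asymmetric int weights (group_size=32) with a 6-bit symmetric int quantization of the scales (double_quant_group_size=8); the other preset, "BNB_NF4", is exercised by the tests further down in this diff. A minimal usage sketch follows, assuming the tiny GPT-J model and entry points used by those tests; it is illustrative only and not part of the patch.

import copy
import torch
import transformers

from neural_compressor.torch.quantization import get_default_double_quant_config, quantize

# Tiny GPT-J model and sample input, mirroring the 3.x unit tests.
model = transformers.AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")
example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long)

# Select the re-enabled preset; the RTN tests below apply it the same way,
# with no calibration run_fn required.
double_quant_config = get_default_double_quant_config(type="GGML_TYPE_Q4_K")
q_model = quantize(copy.deepcopy(model), double_quant_config)
out = q_model(example_inputs)[0]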
diff --git a/test/3x/torch/quantization/weight_only/test_gptq.py b/test/3x/torch/quantization/weight_only/test_gptq.py index 10ea4385c14..064b75e7203 100644 --- a/test/3x/torch/quantization/weight_only/test_gptq.py +++ b/test/3x/torch/quantization/weight_only/test_gptq.py @@ -1,295 +1,183 @@ import copy -import unittest +import pytest import torch +import transformers -from neural_compressor.common import Logger -from neural_compressor.torch.quantization import GPTQConfig, quantize +from neural_compressor.torch.algorithms.weight_only import WeightOnlyLinear +from neural_compressor.torch.quantization import GPTQConfig, get_default_gptq_config, get_default_rtn_config, quantize -logger = Logger().get_logger() +def run_fn(model): + # GPTQ uses ValueError to reduce computation when collecting input data of the first block + # It's special for UTs, no need to add this wrapper in examples. + with pytest.raises(ValueError): + model(torch.tensor([[10, 20, 30]], dtype=torch.long)) + model(torch.tensor([[40, 50, 60]], dtype=torch.long)) -def get_gpt_j(): - import transformers - tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained( - "hf-internal-testing/tiny-random-GPTJForCausalLM", - torchscript=True, - ) - return tiny_gptj - - -class GPTQLLMDataLoader: - def __init__(self, length=512): - self.batch_size = 1 - self.length = length - - def __iter__(self): - for i in range(10): - yield torch.ones([1, self.length], dtype=torch.long) - - -class GPTQLLMDataLoaderList(GPTQLLMDataLoader): - def __iter__(self): - for i in range(10): - yield (torch.ones([1, self.length], dtype=torch.long), torch.ones([1, self.length], dtype=torch.long)) - - -class GPTQLLMDataLoaderDict(GPTQLLMDataLoader): - def __iter__(self): - for i in range(10): - yield { - "input_ids": torch.ones([1, self.length], dtype=torch.long), - "attention_mask": torch.ones([1, self.length], dtype=torch.long), - } - - -from tqdm import tqdm - -from neural_compressor.torch.algorithms.weight_only.gptq import move_input_to_device - - -def run_fn_for_gptq(model, dataloader_for_calibration, *args): - logger.info("Collecting calibration inputs...") - for batch in tqdm(dataloader_for_calibration): - batch = move_input_to_device(batch, device=None) - try: - if isinstance(batch, tuple) or isinstance(batch, list): - model(batch[0]) - elif isinstance(batch, dict): - model(**batch) - else: - model(batch) - except ValueError: - pass - return - - -class TestGPTQ(unittest.TestCase): - @classmethod - def setUpClass(self): - pass +class TestGPTQQuant: + def setup_class(self): + self.tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-GPTJForCausalLM", + ) + self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long) + # record label for comparison + self.label = self.tiny_gptj(self.example_inputs)[0] - @classmethod - def tearDownClass(self): + def teardown_class(self): pass - def setUp(self): - # print the test name - logger.info(f"Running TestGPTQ test: {self.id()}") - - def test_gptq(self): - # Ported from test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py - # TestPytorchWeightOnlyAdaptor.test_GPTQ_fixed_length_quant - - dataloader = GPTQLLMDataLoader() - - # case 1: tensor - model = get_gpt_j() - input = torch.ones([1, 512], dtype=torch.long) - out0 = model(input) - model_1 = copy.deepcopy(model) - device = None - from neural_compressor.torch.algorithms.weight_only.gptq import DataloaderPreprocessor - - dataloaderPreprocessor = DataloaderPreprocessor( - 
dataloader_original=dataloader, use_max_length=False, pad_max_length=512, nsamples=128 - ) - dataloader_for_calibration = dataloaderPreprocessor.get_prepared_dataloader() - + def test_accuracy_improvement(self): + # test_default_rtn_config + model = copy.deepcopy(self.tiny_gptj) + quant_config = get_default_rtn_config() + model = quantize(model, quant_config, run_fn=run_fn) + rtn_label = model(self.example_inputs)[0] + rtn_atol = (rtn_label - self.label).amax() + # test_default_gptq_config + model = copy.deepcopy(self.tiny_gptj) + quant_config = get_default_gptq_config() + model = quantize(model, quant_config, run_fn=run_fn) + gptq_label = model(self.example_inputs)[0] + gptq_atol = (gptq_label - self.label).amax() + # 0.05 VS 0.08 + assert gptq_atol < rtn_atol, "GPTQ should have lower atol than RTN, please double check." + + @pytest.mark.parametrize( + "bits, use_sym, group_size", + [ + (8, True, 128), + (4, True, 128), + (4, False, 32), + (4, True, 32), + (4, False, -1), + (2, True, 8), + ], + ) + def test_int_params(self, bits, use_sym, group_size): + model = copy.deepcopy(self.tiny_gptj) quant_config = GPTQConfig( - weight_group_size=8, dataloader_len=len(dataloader_for_calibration), pad_max_length=512 - ) - quant_config.set_local("lm_head", GPTQConfig(weight_dtype="fp32")) - logger.info(f"Test GPTQ with config {quant_config}") - q_model = quantize( - model=model_1, quant_config=quant_config, run_fn=run_fn_for_gptq, run_args=dataloader_for_calibration + bits=bits, + use_sym=use_sym, + group_size=group_size, ) - out1 = q_model(input) - self.assertTrue(torch.allclose(out1[0], out0[0], atol=1e-02)) - - # NF4 - model_1 = copy.deepcopy(model) + model = quantize(model, quant_config, run_fn=run_fn) + out = model(self.example_inputs)[0] + assert (out != self.label).all(), "WOQ output should be different from raw output" + if (bits, use_sym, group_size) == (8, True, 128): + assert torch.allclose(out, self.label, atol=0.005), "Accuracy gap atol > 0.005 is unexpected." + if (bits, use_sym, group_size) in [(4, True, 128), (4, True, 32), (4, False, 32), (4, False, -1)]: + assert torch.allclose(out, self.label, atol=0.08), "Accuracy gap atol > 0.08 is unexpected." + if (bits, use_sym, group_size) in [(2, True, 8)]: + assert torch.allclose(out, self.label, atol=0.25), "Accuracy gap atol > 0.25 is unexpected."
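A side note on the run_fn used above: GPTQ deliberately raises ValueError once it has captured the inputs of the first transformer block, which is why the test wraps the first forward call in pytest.raises. Outside the unit tests a calibration run_fn would simply iterate a dataloader and swallow that error, as the autotune test later in this diff does. A hedged sketch, where calib_dataloader is a placeholder, not an API from the patch:

def run_fn_for_calibration(model, calib_dataloader):
    # Feed a few calibration batches; GPTQ raises ValueError after the first
    # block's inputs have been collected, so each forward is wrapped.
    for batch in calib_dataloader:
        try:
            model(batch)
        except ValueError:
            pass

# Usage sketch:
# quantize(model, get_default_gptq_config(), run_fn=run_fn_for_calibration,
#          run_args=(calib_dataloader,))  # run_args must be a tuple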
+ + def test_mse_search(self): + # use_mse_search=False + model = copy.deepcopy(self.tiny_gptj) quant_config = GPTQConfig( - weight_dtype="nf4", weight_group_size=8, dataloader_len=len(dataloader_for_calibration), pad_max_length=512 + use_mse_search=False, ) - quant_config.set_local("lm_head", GPTQConfig(weight_dtype="fp32")) - logger.info(f"Test GPTQ with config {quant_config}") - q_model = quantize( - model=model_1, quant_config=quant_config, run_fn=run_fn_for_gptq, run_args=dataloader_for_calibration - ) - out1 = q_model(input) - self.assertTrue(torch.allclose(out1[0], out0[0], atol=1e-02)) - - def test_gptq_advance(self): - # Ported from test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py - # TestPytorchWeightOnlyAdaptor.test_GPTQ_fixed_length_quant - - dataloader = GPTQLLMDataLoader() - model_1 = get_gpt_j() - input = torch.ones([1, 512], dtype=torch.long) - out0 = model_1(input) - - device = None - from neural_compressor.torch.algorithms.weight_only.gptq import DataloaderPreprocessor - - dataloaderPreprocessor = DataloaderPreprocessor( - dataloader_original=dataloader, use_max_length=False, pad_max_length=512, nsamples=128 - ) - dataloader_for_calibration = dataloaderPreprocessor.get_prepared_dataloader() - + model = quantize(model, quant_config, run_fn=run_fn) + out = model(self.example_inputs)[0] + atol_false = (out - self.label).amax() + # use_mse_search=True + model = copy.deepcopy(self.tiny_gptj) quant_config = GPTQConfig( - weight_group_size=8, - dataloader_len=len(dataloader_for_calibration), - act_order=True, - enable_mse_search=True, - pad_max_length=512, + use_mse_search=True, ) - quant_config.set_local("lm_head", GPTQConfig(weight_dtype="fp32")) - logger.info(f"Test GPTQ with config {quant_config}") - q_model = quantize( - model=model_1, quant_config=quant_config, run_fn=run_fn_for_gptq, run_args=dataloader_for_calibration + model = quantize(model, quant_config, run_fn=run_fn) + out = model(self.example_inputs)[0] + atol_true = (out - self.label).amax() + # compare atol, this case is an ideal case. + assert ( + atol_false > atol_true + ), "use_mse_search=True doesn't help accuracy, maybe is reasonable, please double check." 
+ + # def test_layer_wise(self): + # model = copy.deepcopy(self.tiny_gptj) + # quant_config = GPTQConfig( + # use_layer_wise=True, + # ) + # model = quantize(model, quant_config, run_fn=run_fn) + # TODO: (Xin) not implemented + + @pytest.mark.parametrize("dtype", ["int4", "nf4", "fp4"]) + def test_export_compressed_model(self, dtype): + # export_compressed_model = False + model = copy.deepcopy(self.tiny_gptj) + quant_config = GPTQConfig( + dtype=dtype, + export_compressed_model=False, ) - out1 = q_model(input) - self.assertTrue(torch.allclose(out1[0], out0[0], atol=1e-02)) - - def _apply_gptq(self, input, model, quant_config, run_fn, run_args): - logger.info(f"Test GPTQ with config {quant_config}") - out0 = model(input) - q_model = quantize(model=model, quant_config=quant_config, run_fn=run_fn, run_args=run_args) - out1 = q_model(input) - self.assertTrue(torch.allclose(out1[0], out0[0], atol=1e-02)) - - def test_more_gptq(self): - import random - from itertools import product - - # some tests were skipped to accelerate the CI - input = torch.ones([1, 512], dtype=torch.long) - # dataloader - dataloader_collections = [GPTQLLMDataLoader, GPTQLLMDataLoaderList, GPTQLLMDataLoaderDict] - gptq_options = { - "weight_sym": [False, True], - "weight_group_size": [8], - "use_max_length": [False, True], - "pad_max_length": [512], - } - for dataloader_cls in dataloader_collections: - for value in product(*gptq_options.values()): - d = dict(zip(gptq_options.keys(), value)) - quant_config = GPTQConfig(**d) - length = 512 if quant_config.use_max_length else random.randint(1, 1024) - from neural_compressor.torch.algorithms.weight_only.gptq import DataloaderPreprocessor - - dataloaderPreprocessor = DataloaderPreprocessor( - dataloader_original=dataloader_cls(length), - use_max_length=d["use_max_length"], - pad_max_length=d["pad_max_length"], - nsamples=128, - ) - dataloader_for_calibration = dataloaderPreprocessor.get_prepared_dataloader() - quant_config.dataloader_len = len(dataloader_for_calibration) - - self._apply_gptq( - model=get_gpt_j(), - input=input, - quant_config=quant_config, - run_fn=run_fn_for_gptq, - run_args=dataloader_for_calibration, - ) - - def test_gptq_wbits(self): - import copy - import random - - class GPTQLLMDataLoader: - def __init__(self): - self.batch_size = 1 - - def __iter__(self): - for i in range(20): - length = random.randint(1, 1024) - yield torch.ones([1, length], dtype=torch.long) - - dataloader = GPTQLLMDataLoader() - model = copy.deepcopy(get_gpt_j()) - weight_config = { - "transformer.h.0.attn.k_proj": { - "wbits": 4, - "group_size": 128, - "sym": True, - "percdamp": 0.01, - "perchannel": False, - }, - "transformer.h.1.attn.k_proj": { - "wbits": 3, - "group_size": -1, - "sym": False, - "percdamp": 0.01, - "act_order": True, - }, - "transformer.h.2.attn.k_proj": { - "wbits": 3, - "group_size": 32, - "sym": False, - "percdamp": 0.01, - "mse": True, - "act_order": False, - }, - "transformer.h.3.attn.k_proj": { - "wbits": 3, - "group_size": 256, - "sym": False, - "percdamp": 0.01, - "mse": True, - "act_order": False, - }, - } - from neural_compressor.torch.algorithms.weight_only.gptq import DataloaderPreprocessor - - dataloaderPreprocessor = DataloaderPreprocessor( - dataloader_original=dataloader, - use_max_length=True, - pad_max_length=512, - nsamples=128, + model = quantize(model, quant_config, run_fn=run_fn) + out1 = model(self.example_inputs)[0] + # export_compressed_model = True + model = copy.deepcopy(self.tiny_gptj) + quant_config = GPTQConfig( + dtype=dtype, + 
export_compressed_model=True, ) - preprocessed_dataloader = dataloaderPreprocessor.get_prepared_dataloader() - from neural_compressor.torch.algorithms.weight_only.gptq import GPTQuantizer - - quantizer = GPTQuantizer( - model=model, - weight_config=weight_config, - dataloader_len=13, - use_max_length=True, - pad_max_length=512, - run_fn=run_fn_for_gptq, - run_args=preprocessed_dataloader, + model = quantize(model, quant_config, run_fn=run_fn) + out2 = model(self.example_inputs)[0] + assert isinstance(model.transformer.h[0].attn.k_proj, WeightOnlyLinear), "Exporting compressed model failed." + + # The small gap is caused by FP16 scale in WeightOnlyLinear. + if dtype == "int4": + atol_true = (out1 - out2).amax() + assert ( + atol_true < 0.008 + ), "Exporting compressed model should have the same output as quantized model. Please double check" + else: + assert torch.allclose( + out1, out2 + ), "Exporting compressed model should have the same output as quantized model. Please double check." + + @pytest.mark.parametrize("dtype", ["int4", "nf4", "fp4", "fp4_e2m1_bnb", "fp4_e2m1"]) + def test_dtype_params(self, dtype): + model = copy.deepcopy(self.tiny_gptj) + quant_config = GPTQConfig( + dtype=dtype, ) - quantizer.execute_quantization() - self.assertTrue(isinstance(model, torch.nn.Module)) - self.gptj = get_gpt_j() - - model = copy.deepcopy(self.gptj) - weight_config = {"wbits": 4} - dataloaderPreprocessor = DataloaderPreprocessor( - dataloader_original=dataloader, - use_max_length=False, - pad_max_length=512, - nsamples=128, + model = quantize(model, quant_config, run_fn=run_fn) + out = model(self.example_inputs)[0] + atol = (out - self.label).amax() + assert atol < 0.12, "Accuracy gap atol > 0.12 is unexpected. Please double check." + + @pytest.mark.parametrize("dtype", ["nf4", "int4"]) + @pytest.mark.parametrize("double_quant_bits", [6]) + @pytest.mark.parametrize("double_quant_group_size", [8, 256]) + # TODO: (Xin) to implement + # @pytest.mark.parametrize('export_compressed_model', [False, True]) + def test_double_quant_params(self, dtype, double_quant_bits, double_quant_group_size): + model = copy.deepcopy(self.tiny_gptj) + # double_quant_use_sym = False + quant_config = GPTQConfig( + dtype=dtype, + use_double_quant=True, + double_quant_bits=double_quant_bits, + double_quant_use_sym=False, + double_quant_group_size=double_quant_group_size, ) - quantizer = GPTQuantizer( - model=model, - weight_config=weight_config, - dataloader_len=13, - use_max_length=False, - pad_max_length=512, - run_fn=run_fn_for_gptq, - run_args=preprocessed_dataloader, + model = quantize(model, quant_config, run_fn=run_fn) + out = model(self.example_inputs)[0] + atol_false = (out - self.label).amax() + model = copy.deepcopy(self.tiny_gptj) + # double_quant_use_sym = True + quant_config = GPTQConfig( + dtype=dtype, + use_double_quant=True, + double_quant_bits=double_quant_bits, + double_quant_use_sym=True, + double_quant_group_size=double_quant_group_size, ) - quantizer.execute_quantization() - preprocessed_dataloader = dataloaderPreprocessor.get_prepared_dataloader() - self.assertTrue(isinstance(model, torch.nn.Module)) - - -if __name__ == "__main__": - unittest.main() + model = quantize(model, quant_config, run_fn=run_fn) + out = model(self.example_inputs)[0] + atol_true = (out - self.label).amax() + # compare atol, this case is not an ideal case. + try: + assert ( + atol_false < atol_true + ), "asym for double quant should have smaller atol because scales is bigger than zero, please double check." 
+ except: + assert torch.allclose(atol_false, atol_true, atol=0.008), "atol is very close, double checked the logic." diff --git a/test/3x/torch/quantization/weight_only/test_rtn.py b/test/3x/torch/quantization/weight_only/test_rtn.py index bb64c465587..47367841ce7 100644 --- a/test/3x/torch/quantization/weight_only/test_rtn.py +++ b/test/3x/torch/quantization/weight_only/test_rtn.py @@ -4,13 +4,13 @@ import torch import transformers +from neural_compressor.torch.algorithms.weight_only import WeightOnlyLinear from neural_compressor.torch.quantization import ( RTNConfig, get_default_double_quant_config, get_default_rtn_config, quantize, ) -from neural_compressor.torch.quantization.modules import WeightOnlyLinear class TestRTNQuant: @@ -43,7 +43,6 @@ def teardown_class(self): ], ) def test_int_params(self, bits, use_sym, group_size, group_dim): - print(bits, use_sym, group_size, group_dim) model = copy.deepcopy(self.tiny_gptj) quant_config = RTNConfig( bits=bits, @@ -58,7 +57,7 @@ def test_int_params(self, bits, use_sym, group_size, group_dim): assert torch.allclose(out, self.label, atol=0.01), "Accuracy gap atol > 0.01 is unexpected." if (bits, use_sym, group_size, group_dim) == [(4, True, 128, 0), (4, True, 32, 1)]: assert torch.allclose(out, self.label, atol=0.1), "Accuracy gap atol > 0.1 is unexpected." - if (bits, use_sym, group_size, group_dim) == [(4, False, 32, 0), (4, False, -1, 1), (2, True, 16, 1)]: + if (bits, use_sym, group_size, group_dim) == [(4, False, 32, 0), (4, False, -1, 1), (2, True, 8, 1)]: assert torch.allclose(out, self.label, atol=0.5), "Accuracy gap atol > 0.5 is unexpected." def test_full_range(self): @@ -70,7 +69,7 @@ def test_full_range(self): ) model = quantize(model, quant_config) out = model(self.example_inputs)[0] - atol_false = (out - self.label).max() + atol_false = (out - self.label).amax() # use_full_range=True model = copy.deepcopy(self.tiny_gptj) quant_config = RTNConfig( @@ -79,7 +78,7 @@ def test_full_range(self): ) model = quantize(model, quant_config) out = model(self.example_inputs)[0] - atol_true = (out - self.label).max() + atol_true = (out - self.label).amax() # compare atol, this case is an ideal case. assert ( atol_false > atol_true @@ -93,7 +92,7 @@ def test_mse_search(self): ) model = quantize(model, quant_config) out = model(self.example_inputs)[0] - atol_false = (out - self.label).max() + atol_false = (out - self.label).amax() # use_mse_search=True model = copy.deepcopy(self.tiny_gptj) quant_config = RTNConfig( @@ -101,7 +100,7 @@ def test_mse_search(self): ) model = quantize(model, quant_config) out = model(self.example_inputs)[0] - atol_true = (out - self.label).max() + atol_true = (out - self.label).amax() # compare atol, this case is not an ideal case. try: assert ( @@ -116,7 +115,7 @@ def test_layer_wise(self): use_layer_wise=True, ) model = quantize(model, quant_config) - # TODO(Xin): not implemented + # TODO: (Xin) not implemented @pytest.mark.parametrize("dtype", ["int4", "nf4", "fp4"]) def test_export_compressed_model(self, dtype): @@ -130,7 +129,7 @@ def test_export_compressed_model(self, dtype): model = quantize(model, quant_config) out = model(self.example_inputs)[0] assert isinstance(model.lm_head, WeightOnlyLinear), "Exporting compressed model failed." - atol_true = (out - self.q_label).max() + atol_true = (out - self.q_label).amax() # The small gap is caused by FP16 scale in WeightOnlyLinear. 
assert ( atol_true < 0.0005 @@ -169,7 +168,7 @@ def test_dtype_params(self, dtype): @pytest.mark.parametrize("dtype", ["int4", "nf4"]) @pytest.mark.parametrize("double_quant_bits", [6]) @pytest.mark.parametrize("double_quant_group_size", [8, 256]) - # TODO(Xin): to implement + # TODO: (Xin) to implement # @pytest.mark.parametrize('export_compressed_model', [False, True]) def test_double_quant_params(self, dtype, double_quant_bits, double_quant_group_size): model = copy.deepcopy(self.tiny_gptj) @@ -183,7 +182,7 @@ def test_double_quant_params(self, dtype, double_quant_bits, double_quant_group_ ) model = quantize(model, quant_config) out = model(self.example_inputs)[0] - atol_false = (out - self.q_label).max() + atol_false = (out - self.q_label).amax() model = copy.deepcopy(self.tiny_gptj) # double_quant_use_sym = True quant_config = RTNConfig( @@ -195,7 +194,7 @@ def test_double_quant_params(self, dtype, double_quant_bits, double_quant_group_ ) model = quantize(model, quant_config) out = model(self.example_inputs)[0] - atol_true = (out - self.q_label).max() + atol_true = (out - self.q_label).amax() # compare atol, this case is an ideal case. assert ( atol_false < atol_true @@ -208,3 +207,17 @@ def test_double_quant_constants(self): model = quantize(model, double_quant_config_dict) out = model(self.example_inputs)[0] assert torch.allclose(out, self.label, atol=0.1), "Accuracy gap atol > 0.1 is unexpected." + # type="BNB_NF4" + model = copy.deepcopy(self.tiny_gptj) + double_quant_config_dict = get_default_double_quant_config(type="BNB_NF4") + model = quantize(model, double_quant_config_dict) + out1 = model(self.example_inputs)[0] + atol_BNB = (out1 - self.label).amax() + assert torch.allclose(out, out1), "Accuracy should be the same, please double check." + # type="GGML_TYPE_Q4_K" + model = copy.deepcopy(self.tiny_gptj) + double_quant_config_dict = get_default_double_quant_config(type="GGML_TYPE_Q4_K") + model = quantize(model, double_quant_config_dict) + out1 = model(self.example_inputs)[0] + atol_GGML = (out1 - self.label).amax() + assert atol_BNB < atol_GGML, "atol_BNB should be smaller than atol_GGML due to its asym double_quant."
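As the updated imports in these tests show, WeightOnlyLinear now lives under neural_compressor.torch.algorithms.weight_only. A short sketch of the export path the tests exercise, reusing the model loaded in the earlier sketch; parameter names follow the RTNConfig definition in config.py above, and this is illustrative rather than part of the patch.

import copy

from neural_compressor.torch.algorithms.weight_only import WeightOnlyLinear
from neural_compressor.torch.quantization import RTNConfig, quantize

q_model = quantize(copy.deepcopy(model), RTNConfig(dtype="int4", export_compressed_model=True))
# With export_compressed_model=True, quantized Linear layers are replaced by
# WeightOnlyLinear, which stores packed integer weights plus scales (and zero
# points for asymmetric schemes) and recovers the fp32 weight on the fly.
assert isinstance(q_model.lm_head, WeightOnlyLinear)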
diff --git a/test/3x/torch/test_autotune.py b/test/3x/torch/test_autotune.py index 3dca0ebb612..6f267ab5a96 100644 --- a/test/3x/torch/test_autotune.py +++ b/test/3x/torch/test_autotune.py @@ -5,7 +5,6 @@ import torch import transformers -from neural_compressor.torch.algorithms.weight_only.gptq import DataloaderPreprocessor from neural_compressor.torch.quantization import RTNConfig, TuningConfig, autotune, get_all_config_set from neural_compressor.torch.utils import constants, logger @@ -104,7 +103,7 @@ def __iter__(self): from neural_compressor.torch.algorithms.weight_only.gptq import move_input_to_device -def run_fn_for_gptq(model, dataloader_for_calibration, *args): +def run_fn_for_gptq(model, dataloader_for_calibration, calibration_mode=False): logger.info("Collecting calibration inputs...") for batch in tqdm(dataloader_for_calibration): batch = move_input_to_device(batch, device=None) @@ -117,7 +116,8 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args): model(batch) except ValueError: pass - return + if not calibration_mode: + print("Accuracy: 1.0") # demo the usage class TestAutoTune(unittest.TestCase): @@ -179,38 +179,31 @@ def eval_perf_fn(model) -> float: @reset_tuning_target def test_autotune_get_config_set_api(self): - dataloader = GPTQLLMDataLoader() - - model = get_gpt_j() - input = torch.ones([1, 512], dtype=torch.long) - - dataloaderPreprocessor = DataloaderPreprocessor( - dataloader_original=dataloader, use_max_length=False, pad_max_length=512, nsamples=128 - ) - dataloader_for_calibration = dataloaderPreprocessor.get_prepared_dataloader() - - def eval_acc_fn(model) -> float: - return 1.0 - - def eval_perf_fn(model) -> float: - return 1.0 - - eval_fns = [ - {"eval_fn": eval_acc_fn, "weight": 0.5, "name": "accuracy"}, - { - "eval_fn": eval_perf_fn, - "weight": 0.5, - }, - ] - custom_tune_config = TuningConfig(config_set=get_all_config_set(), max_trials=4) - best_model = autotune( - model=get_gpt_j(), - tune_config=custom_tune_config, - eval_fns=eval_fns, - run_fn=run_fn_for_gptq, - run_args=dataloader_for_calibration, - ) - self.assertIsNotNone(best_model) + for dataloader in [GPTQLLMDataLoader(), GPTQLLMDataLoaderList(), GPTQLLMDataLoaderDict()]: + model = get_gpt_j() + + def eval_acc_fn(model) -> float: + return 1.0 + + def eval_perf_fn(model) -> float: + return 1.0 + + eval_fns = [ + {"eval_fn": eval_acc_fn, "weight": 0.5, "name": "accuracy"}, + { + "eval_fn": eval_perf_fn, + "weight": 0.5, + }, + ] + custom_tune_config = TuningConfig(config_set=get_all_config_set(), max_trials=4) + best_model = autotune( + model=model, + tune_config=custom_tune_config, + eval_fns=eval_fns, + run_fn=run_fn_for_gptq, + run_args=(dataloader, True), # run_args should be a tuple + ) + self.assertIsNotNone(best_model) @reset_tuning_target def test_autotune_not_eval_func(self): diff --git a/test/3x/torch/test_config.py b/test/3x/torch/test_config.py index 7c25dc36c6d..a14fadf68f4 100644 --- a/test/3x/torch/test_config.py +++ b/test/3x/torch/test_config.py @@ -260,9 +260,9 @@ def test_config_mapping(self): self.assertTrue(configs_mapping[("fc3", torch.nn.Linear)].bits == 5) def test_gptq_config(self): - gptq_config1 = GPTQConfig(weight_bits=8, pad_max_length=512) + gptq_config1 = GPTQConfig(bits=8, act_order=True) quant_config_dict = { - "gptq": {"weight_bits": 8, "pad_max_length": 512}, + "gptq": {"bits": 8, "act_order": True}, } gptq_config2 = GPTQConfig.from_dict(quant_config_dict["gptq"]) self.assertEqual(gptq_config1.to_dict(), gptq_config2.to_dict())
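For reference, the reworked GPTQConfig now mirrors RTNConfig. The sketch below shows the new parameter spelling, with the previous names in comments where a direct rename applies, and the from_dict round trip exercised by test_config.py above; values are the signature defaults from config.py, and the snippet is illustrative only.

from neural_compressor.torch.quantization import GPTQConfig

cfg = GPTQConfig(
    dtype="int",           # was weight_dtype
    bits=4,                # was weight_bits
    use_sym=True,          # was weight_sym
    group_size=32,         # was weight_group_size
    use_mse_search=False,  # was enable_mse_search
    use_layer_wise=False,  # was layer_wise
    act_order=False,
    percdamp=0.01,
    block_size=2048,       # default raised from 128 to 2048
    static_groups=False,   # new option in this patch
)

# Partial dicts are accepted, matching test_config.py.
cfg1 = GPTQConfig(bits=8, act_order=True)
cfg2 = GPTQConfig.from_dict({"bits": 8, "act_order": True})
assert cfg1.to_dict() == cfg2.to_dict()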