From 57003a5d5b0fd8375eec0c26d95f755e8204c9b8 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 19 Mar 2024 16:45:44 -0700 Subject: [PATCH] Autoquant Summary: Adding autoquantization functionality. Using the autoquant api we can test kernel speeds and pick the best quantization type (or no quantization) for each layer. Test Plan: python test/test.py -k "autoquant" also tested on SAM and SDXL https://github.com/pytorch-labs/segment-anything-fast/pull/114 https://github.com/HDCharles/sdxl-fast/commit/8d9942ab05a552f25f5bfe09da02719ce255467f Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 94089f74edf54f8e2122e91498b25306d322f3ab Pull Request resolved: https://github.com/pytorch-labs/ao/pull/38 --- README.md | 34 +++- __init__.py | 0 test/test.py | 77 +++++++ torchao/__init__.py | 23 ++- torchao/quantization/__init__.py | 3 + torchao/quantization/autoquant.py | 324 ++++++++++++++++++++++++++++++ torchao/quantization/quant_api.py | 65 +++++- torchao/quantization/subclass.py | 2 +- 8 files changed, 514 insertions(+), 14 deletions(-) create mode 100644 __init__.py create mode 100644 torchao/quantization/autoquant.py diff --git a/README.md b/README.md index 8adcd9c24c..1bcaa83877 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ -# torchao: PyTorch Architecture Optimization +# torchao: PyTorch Architecture Optimization **Note: This repository is currently under heavy development - if you have suggestions on the API or use-cases you'd like to be covered, please open a github issue** -The `torchao` package allows you to quantize and prune your models using native PyTorch. +The `torchao` package allows you to quantize and prune your models using native PyTorch. The repo hosts both 1. lower precision [dtypes](./torchao/dtypes) such as nf4, uint4 @@ -38,31 +38,43 @@ pip install -e . Typically quantization algorithms will have different schemes for how the activation and weights are quantized so A16W8 for instance means the activations are quantized to 16 bits whereas the weights are quantized to 8 bits. Trying out different quantization schemes in `torchao` is generally a 1 line change. -### A8W8 Dynamic Quantization +### Autoquantization -```Python +The `autoquant` api can be used to quickly and accurately quantize your model. When used as in the example below, the api first identifies the shapes +of the activations that each linear layer sees. It then benchmarks those shapes across the different quantized and non-quantized layer types in order to pick the fastest one, attempting to take fusions into account where possible. Finally, once the best class is found for a layer, it swaps that class in for the layer's weight. Currently this api chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer.
+ +```python import torch -from torchao.quantization import quant_api +import torchao -# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor -torch._inductor.config.force_fuse_int_mm_with_mul = True +# inductor settings which improve torch.compile performance for quantized modules +torch._inductor.config.force_fuse_int_mm_with_mul = True +torch._inductor.config.use_mixed_mm = True # Plug in your model and example input model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16) input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda') -# convert linear modules to quantized linear modules -quant_api.change_linear_weights_to_int8_dqtensors(model) +# perform autoquantization +torchao.autoquant(model, input) # compile the model to improve performance model = torch.compile(model, mode='max-autotune') model(input) ``` + +### A8W8 Dynamic Quantization + +```python +# convert linear modules to quantized linear modules +torchao.change_linear_weights_to_int8_dqtensors(model) +``` + ### A16W8 WeightOnly Quantization ```python -quant_api.change_linear_weights_to_int8_woqtensors(model) +torchao.change_linear_weights_to_int8_woqtensors(model) ``` This technique works best when the torch._inductor.config.use_mixed_mm option is enabled. This avoids dequantizing the weight tensor before the matmul, instead fusing the dequantization into the matmul, thereby avoiding materialization of a large floating point weight tensor. @@ -71,7 +83,7 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is ### A16W4 WeightOnly Quantization ```python -quant_api.change_linear_weights_to_int4_woqtensors(model) +torchao.change_linear_weights_to_int4_woqtensors(model) ``` Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
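The README example above covers the one-shot `torchao.autoquant` flow. For models whose linear layers see several distinct activation shapes, the patch also exposes the underlying two-step flow (exercised in `test_autoquant_multi_input` below): swap the weights for shape-logging wrappers, run each representative input, then benchmark and finalize. The following is a minimal sketch assuming a CUDA device and a bfloat16 model; the layer size and batch sizes are illustrative only:

```python
import torch
import torchao

# an illustrative model; any module containing nn.Linear layers works
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)

# step 1: swap linear weights for AutoQuantizableLinearWeight wrappers that log shapes
torchao.change_linears_to_autoquantizable(model)

# step 2: run every representative batch size so each activation shape is recorded
for batch_size in (1, 64):
    model(torch.randn(batch_size, 1024, device="cuda", dtype=torch.bfloat16))

# step 3: benchmark the logged shapes and swap in the fastest subclass per layer
torchao.change_autoquantizable_to_quantized(model)

model = torch.compile(model, mode="max-autotune")
```

The one-shot `autoquant` wrapper shown in the README is equivalent to these three steps with a single example input.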
diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/test.py b/test/test.py index fe3b3ec8a7..248e8ee09d 100644 --- a/test/test.py +++ b/test/test.py @@ -12,6 +12,7 @@ import torch.nn as nn from torch._inductor.utils import run_and_get_code from torch._dynamo import config +import torchao from torch.ao.quantization import MinMaxObserver, QConfigMapping from torchao.quantization.dynamic_quant import ( @@ -54,6 +55,13 @@ _fqn_to_op_to_shape_to_count, LoggingTensorMode, ) +from torchao.quantization.autoquant import ( + AQInt8DynamicallyQuantizedLinearWeight, + AQWeightOnlyQuantizedLinearWeight, + AQWeightOnlyQuantizedLinearWeight2, + AQWeightOnlyQuantizedLinearWeight3 + +) from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx import os @@ -880,6 +888,30 @@ def test_int8_weight_only_quant_subclass(self): Int8WeightOnlyQuantizedLinearWeight.from_float, 40, test_dtype ) + def test_aq_int8_dynamic_quant_subclass(self): + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: + self._test_lin_weight_subclass_impl( + AQInt8DynamicallyQuantizedLinearWeight.from_float, 35, test_dtype + ) + + def test_aq_int8_weight_only_quant_subclass(self): + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: + self._test_lin_weight_subclass_impl( + AQWeightOnlyQuantizedLinearWeight.from_float, 35, test_dtype + ) + + def test_aq_int8_weight_only_quant_2_subclass(self): + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: + self._test_lin_weight_subclass_impl( + AQWeightOnlyQuantizedLinearWeight2.from_float, 35, test_dtype + ) + + def test_aq_int8_weight_only_quant_3_subclass(self): + for test_dtype in [torch.float32, torch.float16, torch.bfloat16]: + self._test_lin_weight_subclass_impl( + AQWeightOnlyQuantizedLinearWeight3.from_float, 35, test_dtype + ) + def test_int4_weight_only_quant_subclass(self): self._test_lin_weight_subclass_impl( Int4WeightOnlyQuantizedLinearWeight.from_float, 10, test_shape=[1, 1024, 8] @@ -1195,6 +1227,51 @@ def test_on_dummy_distilbert(self): print("sqnr_pt_quant", sqnr_pt_quant) self.assertTrue(sqnr_sq >= 8.0) +class TestAutoQuant(unittest.TestCase): + def test_autoquant_one_input(self): + torch._inductor.config.epilogue_fusion = False + torch._inductor.config.use_mixed_mm = True + torch._inductor.config.force_fuse_int_mm_with_mul = True + torch._dynamo.config.automatic_dynamic_shapes = False + + for m,k,n in [ + (1, 1024, 1024), + (64, 1024, 1024), + (2**15, 1024, 1024), + (1, 1024, 4096), + (64, 1024, 4096), + (1, 4096, 1024), + (64, 4096, 1024), + (4096, 4096, 1024), + ]: + example_input = torch.randn(m, k, device="cuda", dtype=torch.bfloat16) + model = torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k,n), + torch.nn.ReLU(), + ).to("cuda").to(torch.bfloat16) + out = model(example_input) + torchao.autoquant(model, example_input) + out2 = model(example_input) + sqnr = SQNR(out, out2) + self.assertTrue(sqnr >= 30) + + def test_autoquant_multi_input(self): + m1, m2, k, n = 1, 8, 1024, 1024 + model = torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k,n), + torch.nn.ReLU(), + ).cuda().to(torch.bfloat16) + example_input = torch.randn(m1, k, device="cuda", dtype=torch.bfloat16) + example_input2 = torch.randn(m2, k, device="cuda", dtype=torch.bfloat16) + torchao.change_linears_to_autoquantizable(model) + out=model(example_input) + model(example_input2) + torchao.change_autoquantizable_to_quantized(model) + out2 = model(example_input) + sqnr = SQNR(out, 
out2) + self.assertTrue(sqnr >= 30) if __name__ == "__main__": unittest.main() diff --git a/torchao/__init__.py b/torchao/__init__.py index e0497cc3e2..19d7c097a8 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -1,5 +1,26 @@ +from torchao.quantization import ( + apply_weight_only_int8_quant, + apply_dynamic_quant, + change_linear_weights_to_int8_dqtensors, + change_linear_weights_to_int8_woqtensors, + change_linear_weights_to_int4_woqtensors, + swap_conv2d_1x1_to_linear, + autoquant, + change_linears_to_autoquantizable, + change_autoquantizable_to_quantized, +) from . import dtypes __all__ = [ - "dtypes" + "apply_weight_only_int8_quant", + "apply_dynamic_quant", + "change_linear_weights_to_int8_dqtensors", + "change_linear_weights_to_int8_woqtensors", + "change_linear_weights_to_int4_woqtensors", + "swap_conv2d_1x1_to_linear", + "safe_int_mm", + "autoquant", + "change_linears_to_autoquantizable", + "change_autoquantizable_to_quantized", + "dtypes" ] diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 80599cb71c..1b421ab8e4 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -25,6 +25,9 @@ "dynamically_quantize_per_channel", "dequantize_per_tensor", "dequantize_per_channel", + "autoquant", + "change_linears_to_autoquantizable", + "change_autoquantizable_to_quantized", "quant_int8_dynamic_linear", "quant_int8_matmul", "quant_int8_dynamic_per_token_linear", diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py new file mode 100644 index 0000000000..f05958c84c --- /dev/null +++ b/torchao/quantization/autoquant.py @@ -0,0 +1,324 @@ +import torch +from .subclass import ( # noqa + Int8DynamicallyQuantizedLinearWeight, + Int8WeightOnlyQuantizedLinearWeight, + QuantizedLinearWeightBase, +) +from torch.utils._python_dispatch import return_and_correct_aliasing +from .quant_primitives import ( + quantize_activation_per_token_absmax, + safe_int_mm, +) +import torch.nn.functional as F +from torch._inductor.utils import do_bench +aten = torch.ops.aten + +AUTOQUANT_CACHE = {} + +def check_cache(cls, shapes_and_dtype): + return AUTOQUANT_CACHE.get((cls,)+shapes_and_dtype, None) + +def update_cache(cls, shapes_and_dtype, res): + AUTOQUANT_CACHE[(cls,)+shapes_and_dtype] = res + +class AutoQuantizableLinearWeight(torch.Tensor): + """ + when run, finds best type of quantization for this tensor and swaps itself with that + """ + @staticmethod + def __new__(cls, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs): + kwargs["device"] = weight.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else weight.layout + ) + kwargs["dtype"] = ( + kwargs.get("dtype") if kwargs.get("dtype", False) else weight.dtype + ) + kwargs["requires_grad"] = False + shape = kwargs.pop("shape", weight.shape) + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__(self, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs): + self.weight = weight + self.qtensor_class_list = qtensor_class_list + self.logged_data = {} + self.mode = mode + + def __repr__(self): + return ( + f"{self.__class__.__name__}(data={self.weight}, shape={self.shape}, " + f"device={self.device}, dtype={self.dtype}, qtensor_class_list={self.qtensor_class_list})" + ) + + @staticmethod + def log_shape(act_mat, w_autoquant, bias): + act_mat = act_mat.reshape(-1, act_mat.shape[-1]) + logged_dtype = act_mat.dtype + logged_shapes = (act_mat.shape,
w_autoquant.shape, None if bias is None else bias.shape,) + shapes_and_dtype = logged_shapes + (logged_dtype,) + w_autoquant.logged_data[shapes_and_dtype] = 1 + w_autoquant.logged_data.get(shapes_and_dtype, 0) + for q_cls in w_autoquant.qtensor_class_list: + if check_cache(q_cls, shapes_and_dtype) is None: + update_cache(q_cls, shapes_and_dtype, None) + + def tune_autoquant(self, q_cls, shapes_and_dtype, best_time): + act_shape, w_shape, bias_shape, act_dtype = shapes_and_dtype + if check_cache(q_cls, shapes_and_dtype) is None: + with torch.no_grad(): + act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device) + bias = None if bias_shape is None else torch.randn(bias_shape, dtype=act_dtype, device=self.device) + res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode) + update_cache(q_cls, shapes_and_dtype, res) + + def to_quantized(self, error_on_unseen, **kwargs): + if error_on_unseen and self.logged_data == {}: + raise RuntimeError("must run module normally to get shape, dtype info for autoquant") + elif (self.logged_data == {}) and not error_on_unseen: + # default back to non-quantized weight if not seen + self = AQFloatLinearWeight.from_float(self.weight) + return self + + + # only want to do shape+final print a single time if multiple layers + # see/have same shapes so we gate on check_cache being empty for + # at least one of the class/shape combinations. + do_final_print = False + print_once = True + + def count_shapes(self, do_print=True): + differe_shape_count=0 + for shapes_and_dtype, times_seen in self.logged_data.items(): + differe_shape_count += 1 + if do_print: + act_shape, weight_shape, bias_shape, dtype = shapes_and_dtype + print(f"activation_shapes: {act_shape}, times_seen: {times_seen}") + if do_print: + print(f"weight_shape: {weight_shape}, dtype: {dtype}, bias_shape: {bias_shape}") + return differe_shape_count + + # check each class + best_time = torch.inf + best_cls = None + for q_cls in self.qtensor_class_list: + # for each logged shape+dtype, benchmark + cur_time=0 + shape_count = count_shapes(self, do_print=False) + for shapes_and_dtype, times_seen in self.logged_data.items(): + if check_cache(q_cls, shapes_and_dtype) is None: + # only do final print if we have to autotune at least one cls/shape pair + do_final_print=True + + # only print shapes once + if print_once == True: + print_once = False + count_shapes(self, do_print=True) + + time_for_best_shape = check_cache(best_cls, shapes_and_dtype) + time_for_best_shape = torch.inf if time_for_best_shape is None else time_for_best_shape + self.tune_autoquant(q_cls, shapes_and_dtype, time_for_best_shape) + torch._dynamo.reset() + cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen + if shape_count is not None and shape_count > 1: + print(f">total_time: {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms") + if best_time >= cur_time: + best_time = cur_time + best_cls = q_cls + # only print if this is the first time seeing some cls+shape combo, + # otherwise we will print the same thing for every layer. + if do_final_print: + print(f"best_cls={best_cls}\n") + # TODO handle random cls args/kwargs? or should they be curried? 
+ self = best_cls.from_float(self.weight) + return self + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.weight), self.qtensor_class_list, dtype=self.dtype, mode=self.mode + ) + + def __tensor_flatten__(self): + return ["weight"], [self.qtensor_class_list, self.mode, self.dtype, self.shape] + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + weight = tensor_data_dict["weight"] + qtensor_class_list, mode, dtype, shape = tensor_attributes[0] + return cls(weight, qtensor_class_list, mode, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride) + + @classmethod + def from_float(cls, weight, qtensor_class_list, **kwargs): + return cls(weight, qtensor_class_list, **kwargs) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + kwargs = {} if kwargs is None else kwargs + + if func is torch.nn.functional.linear: + mat1, w_autoquant, bias = ( + args[0], + args[1], + args[2] if len(args)>2 else None + ) + cls.log_shape(mat1, w_autoquant, bias) + return func(mat1, w_autoquant.weight, bias) + try: + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + except: + print(f"ERR: subclass doesn't implement {func}") + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if func is aten.detach.default: + return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)) + +def do_autoquant_bench(op, *args, **kwargs): + """ + runs benchmark op(*args, **kwargs) avoiding torch.compile overhead + """ + rep = kwargs.pop("rep", 100) + warmup = kwargs.pop("warmup", 25) + with torch.no_grad(): + torch.cuda.synchronize() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + op(*args, **kwargs) + stream.synchronize() + torch.cuda.current_stream().wait_stream(stream) + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + op(*args, **kwargs) + res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median") + return res + +def _is_interpolate_mode(mode): + if isinstance(mode, list) and mode[0]=="interpolate" and len(mode)==2 and isinstance(mode[1], float): + return True + return False + +class AQMixin(): + """ + Mixin to turn normal quantized subclasses into autoquantizable ones + """ + @classmethod + def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): + w_qtensor = cls.from_float(weight) + if _is_interpolate_mode(mode): + q_c_op = torch.compile(cls._quantized_op, mode="max-autotune-no-cudagraphs") + else: + func = lambda a,b,c: F.relu(cls._quantized_op(F.relu(a), b, c)) + q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs") + res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=100) + if res < best_time*1.1: + res2 = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=900) + res=(res2*.9+res*.1) + print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ") + return res + +class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight): + """ + AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight + """ + @classmethod + def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): + if not _is_interpolate_mode(mode): + return super()._autoquant_test(act_mat, weight, bias, best_time, mode) + + # SAM best is between 
.8 and 1, SDXL also performs best in this range + INTERPOLATION_CONSTANT = mode[1] + w_qtensor = cls.from_float(weight) + x_vals_int8, x_scales = quantize_activation_per_token_absmax( + act_mat.reshape(-1, act_mat.shape[-1]) + ) + quantized_matmul = ( + lambda x_vals_int8, x_scales, w_vals_int8: + safe_int_mm(x_vals_int8, w_vals_int8) * x_scales + ) + q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs") + with torch.no_grad(): + res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data) + print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms") + + # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op + if res_matmul>=best_time: + return res_matmul + + # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT + to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul) + res = super()._autoquant_test(act_mat, weight, bias, to_beat) + max_int_const_win = (best_time-res_matmul)/(res-res_matmul) + res_f = INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul + print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}") + return res_f + +class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin): + """ + AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight + """ + +class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin): + """ + AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that + uses a different kernel + """ + @staticmethod + def _quantized_op(act_mat, w_qtensor, bias): + orig_dtype = act_mat.dtype + orig_shape = act_mat.shape + act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1) + y = (act_mat*w_qtensor.int_data.unsqueeze(0)).sum(dim=-2) + y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.q_scales + if bias is not None: + y += bias + return y.to(orig_dtype) + + @classmethod + def _autoquant_test(cls, act_mat, *args): + # if act_mat has batchsize>2 don't use this kernel + if act_mat.reshape(-1, act_mat.shape[-1]).shape[0]>32: + return torch.inf + return super()._autoquant_test(act_mat, *args) + +class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin): + """ + AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that + uses a different kernel + """ + def _quantized_op(act_mat, w_qtensor, bias): + orig_shape = act_mat.shape + y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data*w_qtensor.q_scales) + y=y.reshape(*orig_shape[:-1], y.shape[-1]) + if bias is not None: + y += bias + return y + +class AQFloatLinearWeight(torch.Tensor, AQMixin): + """ + A class to be used in concert with AutoQuantizableLinearWeight to provide a + default/non-quantized option. Only implements the bare minimum needed to work with the + AutoQuantizableLinearWeight class using the same interfaces that would normally be + used by QTensor subclasses but for a default linear op instead. Result of from_float + is not a tensor subclass, but rather the float tensor. 
+ """ + def __init__(self): + super().__init__() + + @staticmethod + def _quantized_op(act_mat, w_qtensor, bias): + return torch.nn.functional.linear(act_mat, w_qtensor, bias) + + @classmethod + def from_float(cls, weight): + return weight + +DEFAULT_CLASS_LIST = [ + AQFloatLinearWeight, + AQInt8DynamicallyQuantizedLinearWeight, + AQWeightOnlyQuantizedLinearWeight, + AQWeightOnlyQuantizedLinearWeight2, + # AQWeightOnlyQuantizedLinearWeight3, + # TODO this gets picked in places where it makes perf worse, why? +] diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index b815cbd078..1ad10f1820 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -35,6 +35,7 @@ per_token_dynamic_quant, ) from typing import Dict, Tuple +from .autoquant import AutoQuantizableLinearWeight, DEFAULT_CLASS_LIST __all__ = [ "apply_weight_only_int8_quant", @@ -45,6 +46,9 @@ "swap_conv2d_1x1_to_linear", "Quantizer", "TwoStepQuantizer", + "autoquant", + "change_linears_to_autoquantizable", + "change_autoquantizable_to_quantized", ] ############################# Unified Quantization APIs ############################## @@ -140,9 +144,11 @@ def _get_subclass_inserter(cls, **kwargs): # pyre-fixme[53]: Captured variable `cls` is not annotated. # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. + method = kwargs.pop("method", "from_float") def insert_subclass(lin): lin.weight = torch.nn.Parameter( - cls.from_float(lin.weight, **kwargs), requires_grad=False + # cls.from_float(...) + getattr(cls, method)(lin.weight, **kwargs), requires_grad=False ) return lin @@ -206,6 +212,63 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs): # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. + +def change_linears_to_autoquantizable(model, **kwargs): + """ + Converts all linear weight tensors to the + AutoQuantizableLinearWeight tensor subclass. Expectation is that this is followed + by running the model and then calling change_autoquantizable_to_quantized + """ + filter_fn = kwargs.pop("filter_fn", _is_linear) + kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_CLASS_LIST) + kwargs["mode"] = kwargs.get("mode", ["relu", None]) + _replace_with_custom_fn_if_matches_filter( + model, + _get_subclass_inserter(AutoQuantizableLinearWeight, **kwargs), + filter_fn if filter_fn is not None else _is_linear, + ) + +def change_autoquantizable_to_quantized(model, **kwargs): + """ + Converts AutoQuantizableLinearWeight tensor subclasses + to various quantized/non-quantized tensor subclasses depending + on benchmark results. Expectation is that these modules are + torch.compiled afterwards. 
+ """ + hold = torch._dynamo.config.automatic_dynamic_shapes + torch._dynamo.config.automatic_dynamic_shapes = False + + filter_fn = kwargs.pop( + "filter_fn", + lambda mod, *args: + hasattr(mod, "weight") and isinstance(mod.weight, AutoQuantizableLinearWeight) + ) + error_on_unseen=kwargs.pop("error_on_unseen", True) + _replace_with_custom_fn_if_matches_filter( + model, + _get_subclass_inserter( + AutoQuantizableLinearWeight, method="to_quantized", error_on_unseen=error_on_unseen, **kwargs + ), + filter_fn, + ) + torch._dynamo.config.automatic_dynamic_shapes = hold + torch._dynamo.reset() + +@torch.no_grad() +def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear, mode=["relu",None], **kwargs): + """ + Runs the model with example_input to record shapes and then compares benchmark performance of the seen shape + across the qtensor subclasses in qtensor_class_list. Determines best performing qtensor subclass for each layer + and applies that type of quantization. + """ + change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list, mode=mode, **kwargs) + if not isinstance(example_input, (tuple, list)): + assert isinstance(example_input, torch.Tensor) + example_input = [example_input] + model(*example_input) + change_autoquantizable_to_quantized(model, **kwargs) + return model + def swap_conv2d_1x1_to_linear(model, filter_fn=None): """ Changes all conv2d 1x1 modules to equivalent linear modules so that they can then be quantized. diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 90a2c16cf7..4e319a4658 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -323,7 +323,7 @@ def from_float(cls, input_float, qmin=-128, qmax=127): # however the external representation of our tensor will maintain the correct # shape attribute which needs to be tracked directly. int_data = w_int_repr.contiguous().t() - if cls is not Int8DynamicallyQuantizedLinearWeight: + if not issubclass(cls, Int8DynamicallyQuantizedLinearWeight): int_data = int_data.contiguous() return cls( int_data, w_scales, False, input_float.shape, dtype=input_float.dtype