Update on "Autoquant"
Summary: Adding autoquantization functionality. Using the do_quant api
we can test kernel speeds and pick the best quantization type (or no
quantization) for each layer.

Test Plan: python test/test.py -k "autoquant"

also tested on SAM and SDXL
pytorch-labs/segment-anything-fast#114
HDCharles/sdxl-fast@8d9942a

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
HDCharles committed Mar 19, 2024
1 parent 97733c2 commit 29214a9
Showing 7 changed files with 120 additions and 59 deletions.
41 changes: 31 additions & 10 deletions README.md
@@ -43,29 +43,50 @@ The following apis use quantized [tensor subclasses](https://pytorch.org/docs/st

This tensor subclass method of quantization is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible.

### A8W8 Dynamic Quantization
### Autoquantization

The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this
converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul.

Example
The `autoquant` api can be used to quickly and accurately quantize your model. When used as in the example below, the api first identifies the shapes
of the activations that the different linear layers see. It then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take fusions into account where possible. Finally, once the best class is found for each layer, it swaps the linear weights to that class. Currently this api chooses between no quantization, int8 dynamic quantization, and int8 weight-only quantization for each layer.

```
import torch
from torchao.quantization import quant_api
import torchao
# inductor settings which improve torch.compile runtime for quantized modules
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True
# some user model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
# convert linear modules to quantized linear modules
quant_api.change_linear_weights_to_int8_dqtensors(model)
# perform autoquantization
torchao.autoquant(model, (input))
# compile the model to improve performance
model = torch.compile(model, mode='max-autotune')
model(input)
```
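
If your model's forward takes more than one argument, `autoquant` also accepts a tuple or list of example inputs and unpacks it when it runs the model to record shapes (see `quant_api.py` below). A minimal sketch, where `model2`, `input_a` and `input_b` are hypothetical stand-ins:

```
# hypothetical multi-input model; autoquant calls model2(input_a, input_b) internally
torchao.autoquant(model2, (input_a, input_b))
model2 = torch.compile(model2, mode='max-autotune')
model2(input_a, input_b)
```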


### A8W8 Dynamic Quantization

The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this
converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul.

Example

```
# some user model and example input
...
# convert linear modules to quantized linear modules
torchao.change_linear_weights_to_int8_dqtensors(model)
# compile the model to improve performance
...
```

This technique works best when the `torch._inductor.config.force_fuse_int_mm_with_mul` option is enabled. This allows fusion of the int8*int8 -> int32 matmul and the subsequent mul op, thereby avoiding materialization of the int32 intermediate tensor.
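
A minimal sketch of the A8W8 dynamic quantization flow with this flag enabled, assuming the same `model`/`input` setup as in the autoquant example above:

```
# enable fusion of the int8 matmul with the subsequent rescaling mul op
torch._inductor.config.force_fuse_int_mm_with_mul = True
# convert linear weights to dynamically quantized tensor subclasses, then compile
torchao.change_linear_weights_to_int8_dqtensors(model)
model = torch.compile(model, mode='max-autotune')
model(input)
```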


@@ -81,7 +102,7 @@ Example
...
# convert linear modules to quantized linear modules
quant_api.change_linear_weights_to_int8_woqtensors(model)
torchao.change_linear_weights_to_int8_woqtensors(model)
# compile the model to improve performance
...
@@ -102,7 +123,7 @@ Example
...
# convert linear modules to quantized linear modules
quant_api.change_linear_weights_to_int4_woqtensors(model)
torchao.change_linear_weights_to_int4_woqtensors(model)
# compile the model to improve performance
...
Empty file added __init__.py
Empty file.
24 changes: 24 additions & 0 deletions torchao/__init__.py
@@ -0,0 +1,24 @@
from torchao.quantization import (
apply_weight_only_int8_quant,
apply_dynamic_quant,
change_linear_weights_to_int8_dqtensors,
change_linear_weights_to_int8_woqtensors,
change_linear_weights_to_int4_woqtensors,
swap_conv2d_1x1_to_linear,
autoquant,
change_linears_to_autoquantizable,
change_autoquantizable_to_quantized,
)

__all__ = [
"apply_weight_only_int8_quant",
"apply_dynamic_quant",
"change_linear_weights_to_int8_dqtensors",
"change_linear_weights_to_int8_woqtensors",
"change_linear_weights_to_int4_woqtensors",
"swap_conv2d_1x1_to_linear",
"safe_int_mm",
"autoquant",
"change_linears_to_autoquantizable",
"change_autoquantizable_to_quantized",
]
2 changes: 1 addition & 1 deletion torchao/quantization/__init__.py
@@ -25,7 +25,7 @@
"dynamically_quantize_per_channel",
"dequantize_per_tensor",
"dequantize_per_channel",
"do_autoquant",
"autoquant",
"change_linears_to_autoquantizable",
"change_autoquantizable_to_quantized",
"quant_int8_dynamic_linear",
78 changes: 57 additions & 21 deletions torchao/quantization/autoquant.py
@@ -1,6 +1,4 @@
import torch
import os
from subprocess import check_output
from .subclass import ( # noqa
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
@@ -79,26 +77,56 @@ def to_quantized(self, error_on_unseen, **kwargs):
# default back to non-quantized weight if not seen
self = AQFloatLinearWeight.from_float(self.weight)
return self


# only want to do shape+final print a single time if multiple layers
# see/have same shapes so we gate on check_cache being empty for
# at least one of the class/shape combinations.
do_final_print = False
print_once = True

def count_shapes(self, do_print=True):
different_shape_count=0
for shapes_and_dtype, times_seen in self.logged_data.items():
different_shape_count += 1
if do_print:
act_shape, weight_shape, bias_shape, dtype = shapes_and_dtype
print(f"activation_shapes: {act_shape}, times_seen: {times_seen}")
if do_print:
print(f"weight_shape: {weight_shape}, dtype: {dtype}, bias_shape: {bias_shape}")
return different_shape_count

# check each class
best_time = torch.inf
best_cls = None
do_print=False
# check each class
for q_cls in self.qtensor_class_list:
# for each logged shape+dtype, benchmark
cls_res=0
cur_time=0
shape_count = count_shapes(self, do_print=False)
for shapes_and_dtype, times_seen in self.logged_data.items():
if check_cache(q_cls, shapes_and_dtype) is None:
do_print=True
self.tune_autoquant(q_cls, shapes_and_dtype, best_time)
# only do final print if we have to autotune at least one cls/shape pair
do_final_print=True

# only print shapes once
if print_once == True:
print_once = False
count_shapes(self, do_print=True)

time_for_best_shape = check_cache(best_cls, shapes_and_dtype)
time_for_best_shape = torch.inf if time_for_best_shape is None else time_for_best_shape
self.tune_autoquant(q_cls, shapes_and_dtype, time_for_best_shape)
torch._dynamo.reset()
cls_res += check_cache(q_cls, shapes_and_dtype) * times_seen
if best_time >= cls_res:
best_time = cls_res
cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen
if shape_count is not None and shape_count > 1:
print(f">total_time: {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms")
if best_time >= cur_time:
best_time = cur_time
best_cls = q_cls
# only print if this is the first time seeing some cls+shape combo,
# otherwise we will print the same thing for every layer.
if do_print:
print(f"for {self.logged_data}, best_cls={best_cls}")
if do_final_print:
print(f"best_cls={best_cls}\n")
# TODO handle random cls args/kwargs? or should they be curried?
self = best_cls.from_float(self.weight)
return self
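
In effect, the selection loop above reduces to the following simplified sketch (hypothetical helper names; the real code goes through `check_cache`/`tune_autoquant` with cached per-shape timings, and also handles the one-time printing):

```
def pick_best_cls(qtensor_class_list, logged_data, bench):
    # bench(q_cls, shape_and_dtype) -> benchmarked time in ms for that class on that shape
    best_time, best_cls = float("inf"), None
    for q_cls in qtensor_class_list:
        # total time for this class, weighted by how often each shape was seen
        cur_time = sum(bench(q_cls, s) * seen for s, seen in logged_data.items())
        if cur_time <= best_time:
            best_time, best_cls = cur_time, q_cls
    return best_cls
```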
@@ -145,21 +173,24 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))

def do_autoquant_bench(op, *args, **kwargs):
"""
runs benchmark op(*args, **kwargs) avoiding torch.compile overhead
"""
rep = kwargs.pop("rep", 100)
warmup = kwargs.pop("warmup", 25)
with torch.no_grad():
torch.cuda.synchronize()
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
op(*args)
op(*args, **kwargs)
stream.synchronize()
torch.cuda.current_stream().wait_stream(stream)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
op(*args)
op(*args, **kwargs)
res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median")
return res
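
A hedged usage sketch of this helper; the compiled op and the tensor shapes here are made up for illustration, only the `warmup`/`rep` keywords come from the code above:

```
import torch
import torch.nn.functional as F
from torchao.quantization.autoquant import do_autoquant_bench

op = torch.compile(F.linear, mode="max-autotune-no-cudagraphs")
act = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16)
w = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)
# median runtime in ms, measured by replaying a captured CUDA graph
ms = do_autoquant_bench(op, act, w, warmup=25, rep=100)
```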

@@ -180,11 +211,11 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
else:
func = lambda a,b,c: F.relu(cls._quantized_op(F.relu(a), b, c))
q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs")
res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias)
res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=100)
if res < best_time*1.1:
res2 = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=900)
res=(res2*.9+res*.1)
print(f"time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ")
print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ")
return res

class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
@@ -196,7 +227,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
if not _is_interpolate_mode(mode):
return super()._autoquant_test(act_mat, weight, bias, best_time, mode)

# SAM best is between .8 to 1, SDXL also performs best in this range
# SAM best is between .8 and 1, SDXL also performs best in this range
INTERPOLATION_CONSTANT = mode[1]
w_qtensor = cls.from_float(weight)
x_vals_int8, x_scales = quantize_activation_per_token_absmax(
@@ -209,7 +240,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
with torch.no_grad():
res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data)
print(f"time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")
print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")

# if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
if res_matmul>=best_time:
@@ -220,7 +251,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
res = super()._autoquant_test(act_mat, weight, bias, to_beat)
max_int_const_win = (best_time-res_matmul)/(res-res_matmul)
res_f = INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul
print(f"time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
return res_f
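
Reading the interpolation above: with interpolation constant c, the estimated time is res_f = c*res + (1-c)*res_matmul, so this class still beats best_time exactly when c is below the printed breakeven constant max_int_const_win = (best_time - res_matmul)/(res - res_matmul), assuming res > res_matmul.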

class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
@@ -252,6 +283,10 @@ def _autoquant_test(cls, act_mat, *args):
return super()._autoquant_test(act_mat, *args)

class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
"""
AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
uses a different kernel
"""
def _quantized_op(act_mat, w_qtensor, bias):
orig_shape = act_mat.shape
y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data*w_qtensor.q_scales)
@@ -265,7 +300,8 @@ class AQFloatLinearWeight(torch.Tensor, AQMixin):
A class to be used in concert with AutoQuantizableLinearWeight to provide a
default/non-quantized option. Only implements the bare minimum needed to work with the
AutoQuantizableLinearWeight class using the same interfaces that would normally be
used by QTensor subclasses but for a default linear op instead.
used by QTensor subclasses but for a default linear op instead. Result of from_float
is not a tensor subclass, but rather the float tensor.
"""
def __init__(self):
super().__init__()
@@ -284,5 +320,5 @@ def from_float(cls, weight):
AQWeightOnlyQuantizedLinearWeight,
AQWeightOnlyQuantizedLinearWeight2,
# AQWeightOnlyQuantizedLinearWeight3,
# 3rd version gets picked in situations where it is slower for the interpolation mode
# TODO this gets picked in places where it makes perf worse, why?
]
13 changes: 7 additions & 6 deletions torchao/quantization/quant_api.py
@@ -37,7 +37,7 @@
"change_linear_weights_to_int8_woqtensors",
"change_linear_weights_to_int4_woqtensors",
"swap_conv2d_1x1_to_linear",
"do_autoquant",
"autoquant",
"change_linears_to_autoquantizable",
"change_autoquantizable_to_quantized",
]
@@ -182,6 +182,9 @@ def change_autoquantizable_to_quantized(model, **kwargs):
on benchmark results. Expectation is that these modules are
torch.compiled afterwards.
"""
hold = torch._dynamo.config.automatic_dynamic_shapes
torch._dynamo.config.automatic_dynamic_shapes = False

filter_fn = kwargs.pop(
"filter_fn",
lambda mod, *args:
@@ -195,24 +198,22 @@ def change_autoquantizable_to_quantized(model, **kwargs):
),
filter_fn,
)
torch._dynamo.config.automatic_dynamic_shapes = hold
torch._dynamo.reset()

@torch.no_grad()
def do_autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear, mode=["relu",None], **kwargs):
def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear, mode=["relu",None], **kwargs):
"""
Runs the model with example_input to record shapes and then compares benchmark performance of the seen shape
across the qtensor subclasses in qtensor_class_list. Determines best performing qtensor subclass for each layer
and applies that type of quantization.
"""
hold = torch._dynamo.config.automatic_dynamic_shapes
torch._dynamo.config.automatic_dynamic_shapes = False
change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list, mode=mode, **kwargs)
if not isinstance(example_input, (tuple, list)):
assert isinstance(example_input, torch.Tensor)
example_input = [example_input]
model(*example_input)
change_autoquantizable_to_quantized(model, **kwargs)
torch._dynamo.config.automatic_dynamic_shapes = hold
torch._dynamo.reset()
return model
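
For reference, a hedged sketch of calling this api with a restricted search space via the `qtensor_class_list` parameter; the two classes come from `autoquant.py` above, and `model`/`example_input` are assumed to already exist:

```
import torchao
from torchao.quantization.autoquant import (
    AQFloatLinearWeight,
    AQInt8DynamicallyQuantizedLinearWeight,
)

# keep the non-quantized fallback but only benchmark the int8 dynamic candidate against it
torchao.autoquant(model, example_input, qtensor_class_list=[
    AQFloatLinearWeight,
    AQInt8DynamicallyQuantizedLinearWeight,
])
```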

def swap_conv2d_1x1_to_linear(model, filter_fn=None):
21 changes: 0 additions & 21 deletions torchao/quantization/utils.py
@@ -7,7 +7,6 @@

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.benchmark import Timer

__all__ = [
"find_multiple",
@@ -87,23 +86,3 @@ def get_model_size_in_bytes(model):
for b in model.buffers():
s += b.nelement() * b.element_size()
return s


def benchmark(f, *args, **kwargs):
if "best_time" in kwargs:
best_time = kwargs.pop("best_time")
else:
best_time = torch.inf
t0 = Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)

# warmup
t0.timeit(10)
res=t0.adaptive_autorange(min_run_time=.1)
# run more if median vs median minus iqr (interpolated based on number of runs left) is lower than best_time,
# stop if good res.iqr/res.median or have 20 samples
while res.median-res.iqr+res.iqr*len(res.times)/20 < best_time * 1e-3 and not (res.iqr/res.median<.02 or len(res.times)>=20):
res2 = t0.adaptive_autorange(min_run_time=.5)
res=res.merge([res2, res])[0]
return res.median * 1e3
