support TEQ layerwise config. (#1120)
* support TEQ layerwise config.

* fix bug of folding=false.

* fix ut.

* fix ut.

* fix ut.

* fix comments.

* weight config can exclude specific layer.

* fix coverage issue.
lkk12014402 authored Jul 28, 2023
1 parent 291c4fa commit 9ff7f01
Showing 6 changed files with 138 additions and 136 deletions.
88 changes: 69 additions & 19 deletions neural_compressor/adaptor/pytorch.py
@@ -4589,39 +4589,89 @@ def teq_quantize(self, model, tune_cfg, dataloader, calib_func):
logger.debug("quantizing with the TEQ algorithm")
from .torch_utils.weight_only import teq_quantize
# get example inputs if not provided.
if self.example_inputs is None:
if self.example_inputs is None: # pragma: no cover
if dataloader is None:
assert False, "Please provide dataloader or example_inputs for TEQ algorithm."
try:
for idx, (input, label) in enumerate(dataloader):
self.example_inputs = input
for idx, (x, label) in enumerate(dataloader):
self.example_inputs = x.to(model.device)
break
except:
for idx, input in enumerate(dataloader):
self.example_inputs = input
for idx, x in enumerate(dataloader):
self.example_inputs = x.to(model.device)
break

if 'teq_args' in self.recipes:
wbits = self.recipes.get('wbits', 4)
group_size = self.recipes.get('group_size', 128)
sym = self.recipes.get('scheme', False)
folding = self.recipes.get('folding', True)
folding = True
if 'teq_args' in self.recipes: # pragma: no cover
folding = self.recipes['teq_args'].get('folding', True)

supported_layers = ['Linear']
if folding: # pragma: no cover
from .torch_utils.smooth_quant import GraphTrace
tg = GraphTrace()
absorb_to_layer, _ = tg.get_absorb_to_layer(model, self.example_inputs, supported_layers)
if absorb_to_layer is None or absorb_to_layer == {}:
logger.warning('No absorb layer is detected, skip TEQ algorithm')
return model
else: # pragma: no cover
absorb_to_layer = {}
for name, module in model.named_modules():
for op_type in supported_layers:
if op_type == str(module.__class__.__name__):
absorb_to_layer[name] = [name]

weight_config = {
'wbits': wbits,
'group_size': group_size,
'sym': sym,
'folding': folding
}
quantizer = teq_quantize(
# build a flipped dict from the absorb_to_layer dict
flipped_dict = {}
for k, v in absorb_to_layer.items():
for m in v:
flipped_dict[m] = {'absorb_layer': k}

# check tune_cfg to skip layers without TEQ config
weight_config = {}
skipped_op_name_set = set()
for key, config in tune_cfg['op'].items():
op_name, op_type = key
if config['weight']['dtype'] == 'fp32': # pragma: no cover
if op_name in flipped_dict:
absorb_to_layer.pop(flipped_dict[op_name]['absorb_layer'])
continue
else:
weight_config[op_name] = {}
weight_config[op_name]['bits'] = config['weight']['bits']
weight_config[op_name]['group_size'] = config['weight']['group_size']
weight_config[op_name]['scheme'] = config['weight']['scheme']
if op_name in flipped_dict:
algorithm = config['weight']['algorithm']
if algorithm != 'TEQ':
absorb_to_layer.pop(flipped_dict[op_name]['absorb_layer'])
else:
skipped_op_name_set.add(op_name)
if skipped_op_name_set: # pragma: no cover
logger.info("{} is skipped by TEQ algorithm".format(skipped_op_name_set))

# collect TEQ config from tune_cfg for quantization.
if len(absorb_to_layer) == 0: # pragma: no cover
logger.warning('No absorb layer needs TEQ algorithm, skip it')
else: # pragma: no cover
logger.debug("**absorb layer**: **absorbed layers**")
for k, v in absorb_to_layer.items():
logger.debug(f"{k}: {v}")

logger.info("Absorbed layers with the same absorb layer use the same config")

extra_config = {"folding": folding}

model = teq_quantize(
model,
weight_config,
absorb_to_layer,
extra_config,
dataloader,
example_inputs=self.example_inputs,
calib_func=calib_func
)
return quantizer.model

return model
def awq_quantize(self, model, tune_cfg, dataloader, calib_func):
logger.debug("quantizing with the AWQ algorithm")
from .torch_utils.weight_only import awq_quantize
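To make the new layerwise handling in teq_quantize easier to follow, here is a minimal, self-contained sketch (not the adaptor code itself) of the mapping it builds: the per-op tuning config is checked against a flipped view of absorb_to_layer, and excluding one layer (weight dtype 'fp32') drops its entire absorb group. The module names (model.norm1, model.fc1, ...) and the toy tune_cfg values are assumptions for illustration only.

```python
# Illustrative sketch of the layerwise TEQ config handling; not the adaptor code itself.
absorb_to_layer = {
    "model.norm1": ["model.fc1", "model.fc2"],  # two linears absorbed by one layer
    "model.norm2": ["model.fc3"],
}

# A toy per-op tuning config: (op_name, op_type) -> weight settings.
tune_cfg_op = {
    ("model.fc1", "Linear"): {"weight": {"dtype": "int", "bits": 4, "group_size": 128,
                                         "scheme": "asym", "algorithm": "TEQ"}},
    ("model.fc2", "Linear"): {"weight": {"dtype": "int", "bits": 4, "group_size": 32,
                                         "scheme": "sym", "algorithm": "TEQ"}},
    ("model.fc3", "Linear"): {"weight": {"dtype": "fp32"}},  # excluded layer
}

# Flip absorb_to_layer so each quantizable layer points back at its absorb layer.
flipped = {m: {"absorb_layer": k} for k, v in absorb_to_layer.items() for m in v}

weight_config, skipped = {}, set()
for (op_name, _), cfg in tune_cfg_op.items():
    if cfg["weight"]["dtype"] == "fp32":
        # Excluding a layer removes its whole absorb group from TEQ.
        if op_name in flipped:
            absorb_to_layer.pop(flipped[op_name]["absorb_layer"], None)
        skipped.add(op_name)
        continue
    weight_config[op_name] = {k: cfg["weight"][k] for k in ("bits", "group_size", "scheme")}

print(absorb_to_layer)  # {'model.norm1': ['model.fc1', 'model.fc2']}
print(weight_config)    # per-layer bits / group_size / scheme
print(skipped)          # {'model.fc3'}
```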
4 changes: 2 additions & 2 deletions neural_compressor/adaptor/pytorch_cpu.yaml
@@ -267,7 +267,7 @@
# group_size=-1 means per-channel, others means per-group
'group_size': [32, -1, 1, 4, 8, 16, 64, 128, 256, 512, 1024], # [1-inf], # 32
'scheme': ['sym', 'asym'], # sym, no ZP
'algorithm': ['RTN', 'AWQ', 'GPTQ'], # RTN, [RTN, GPTQ, AWQ,] RTN+AWQ+TEQ order
'algorithm': ['RTN', 'AWQ', 'GPTQ', 'TEQ'], # RTN, [RTN, GPTQ, AWQ,] RTN+AWQ+TEQ order
},
'activation': {
'dtype': ['fp32'],
@@ -445,4 +445,4 @@
'dynamic': *cap_dynamic_s8_1_6,
'quant_aware': *cap_s8_1_6
}
uint8: *cap_s8_1_6
uint8: *cap_s8_1_6
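With 'TEQ' now listed in the weight-only capability above, a layerwise setup can be expressed at the user level: pick TEQ per op type and keep a specific layer in fp32 so it is excluded (which, per the adaptor change above, also skips its absorb group). The sketch below assumes the neural_compressor 2.x PostTrainingQuantConfig / quantization.fit API and an illustrative toy model; the option values and layer names are assumptions and have not been verified end to end against this commit.

```python
# Hedged sketch of a user-level layerwise TEQ config; the toy model and the
# specific option values are assumptions, not part of this commit.
import torch
from neural_compressor import PostTrainingQuantConfig, quantization

class ToyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(16, 16)
        self.lm_head = torch.nn.Linear(16, 4)
    def forward(self, x):
        return self.lm_head(torch.relu(self.fc1(x)))

model = ToyNet()
calib_dataloader = torch.utils.data.DataLoader(torch.randn(32, 16), batch_size=4)

conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={
        ".*": {  # default for every supported op type (Linear)
            "weight": {"bits": 4, "group_size": -1, "scheme": "sym", "algorithm": "TEQ"},
        },
    },
    op_name_dict={
        "lm_head": {"weight": {"dtype": "fp32"}},  # excluded layer stays in fp32
    },
    recipes={"teq_args": {"folding": True}},  # read by the adaptor code above
)

q_model = quantization.fit(model, conf, calib_dataloader=calib_dataloader)
```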
12 changes: 7 additions & 5 deletions neural_compressor/adaptor/torch_utils/model_wrapper.py
@@ -364,7 +364,7 @@ class FakeAffineTensorQuantFunction(Function):
"""

@staticmethod
def forward(ctx, inputs, num_bits=4, group_size=1024):
def forward(ctx, inputs, num_bits=4, group_size=1024, scheme="asym"):
"""
As it will be only applied on activation with per tensor granularity, broadcast is not needed.
@@ -379,7 +379,7 @@ def forward(ctx, inputs, num_bits=4, group_size=1024):
Returns:
outputs: A Tensor of type output_dtype
"""
return quant_weight(inputs, num_bits, group_size)
return quant_weight(inputs, num_bits, group_size, scheme)

@staticmethod
def backward(ctx, grad_outputs):
@@ -391,15 +391,15 @@ def backward(ctx, grad_outputs):
Returns:
grad_inputs: A tensor of gradient
"""
return grad_outputs, None, None
return grad_outputs, None, None, None


class TEQLinearFakeQuant(torch.nn.Module):
"""
wrapper quantization linear
"""

def __init__(self, orig_layer, alpha=None, num_bits=4, group_size=-1):
def __init__(self, orig_layer, alpha=None, num_bits=4, group_size=-1, scheme="asym"):
"""
A forward hook to linear module
:param orig_layer: the original module
@@ -413,6 +413,7 @@ def __init__(self, orig_layer, alpha=None, num_bits=4, group_size=-1):

self.num_bits = num_bits
self.group_size = group_size
self.scheme = scheme

def forward(self, x):
alpha = torch.clip(self.alpha, 1e-5)
@@ -421,7 +422,8 @@ def forward(self, x):
x = x / alpha.view(shape)
weight = self.orig_layer.weight
weight = weight * alpha.unsqueeze(dim=0)
weight_q = FakeAffineTensorQuantFunction().apply(weight, self.num_bits, self.group_size)
weight_q = FakeAffineTensorQuantFunction().apply(weight, self.num_bits,
self.group_size, self.scheme)
return F.linear(x, weight_q, self.orig_layer.bias)
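The scheme argument added above is simply threaded through to quant_weight. As a rough illustration of what the symmetric vs. asymmetric choice means for a weight tensor, here is a stand-alone fake affine quantizer sketch; it is not the library's quant_weight implementation, and the rounding and clamping details are assumptions.

```python
import torch

def fake_quant(w, num_bits=4, group_size=-1, scheme="asym"):
    """Illustrative per-group fake affine quantization (quantize then dequantize)."""
    orig_shape = w.shape
    if group_size > 0:
        w = w.reshape(-1, group_size)  # each row becomes one quantization group
    if scheme == "sym":
        qmax = 2 ** (num_bits - 1) - 1
        scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-9) / qmax
        q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax)
        dq = q * scale
    else:  # "asym": zero-point over the full unsigned range
        qmax = 2 ** num_bits - 1
        wmin = w.amin(dim=-1, keepdim=True)
        wmax = w.amax(dim=-1, keepdim=True)
        scale = (wmax - wmin).clamp(min=1e-9) / qmax
        zp = torch.round(-wmin / scale)
        q = torch.clamp(torch.round(w / scale) + zp, 0, qmax)
        dq = (q - zp) * scale
    return dq.reshape(orig_shape)

w = torch.randn(8, 64)
print((fake_quant(w, scheme="sym") - w).abs().mean())   # symmetric quantization error
print((fake_quant(w, scheme="asym") - w).abs().mean())  # asymmetric quantization error
```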

