GPFQ #666

Merged 12 commits on Sep 27, 2023

2 changes: 2 additions & 0 deletions src/brevitas/graph/calibrate.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC
from copy import deepcopy
from functools import partial
import sys

@@ -279,6 +280,7 @@ def forward_hook_wbiol(self, module, inp, output, name):
# Compute float reference
self.disable_act_quantization(module, is_training=False)
self.disable_param_quantization(module, is_training=False)

out_float = module.forward(*inp) # Required to avoid infinite recursion
self.collect_float_mean(module, out_float, name)
self.enable_act_quantization(module, is_training=False)
227 changes: 227 additions & 0 deletions src/brevitas/graph/gpfq.py
@@ -0,0 +1,227 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from copy import deepcopy
from typing import List, Optional

import numpy as np
import torch
import unfoldNd

from brevitas.graph.gpxq import GPxQ
from brevitas.graph.gpxq import gpxq_mode
from brevitas.graph.gpxq import StopFwdException
from brevitas.graph.gpxq import SUPPORTED_CONV_OP
import brevitas.nn as qnn


class gpfq_mode(gpxq_mode):
"""
Apply GPFQ algorithm.

Args:
model (Module): The model to quantize with GPFQ
        inplace (bool): Whether to apply GPFQ in place or perform a deepcopy. Default: True
        use_quant_activations (bool): Whether to leave activation quantization enabled while
            performing GPFQ. Default: True

Example:
>>> with torch.no_grad():
>>> with gpfq_mode(model) as gpfq:
>>> gpfq_model = gpfq.model
>>> for i in tqdm(range(gpfq.num_layers)):
>>> for img, t in calib_loader:
>>> img = img.cuda()
>>> gpfq_model(img)
>>> gpfq.update()
"""

def __init__(
self,
model,
group_of_parallel_layers: Optional[List[str]] = None,
inplace: bool = True,
use_quant_activations: bool = True,
            p: float = 0.25,
return_forward_output: bool = False,
act_order: bool = False) -> None:
if not inplace:
model = deepcopy(model)
super().__init__(
model,
group_of_parallel_layers,
inplace,
use_quant_activations,
act_order,
return_forward_output)

self.orig_forward = self.model.forward
self.model.forward = self.catch_stopfwd
self.class_implementation = GPFQ
GPFQ.p = p
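        # p is stored on the class because GPFQ instances are constructed elsewhere without a
        # p argument and read it back via GPFQ.p (see GPFQ.__init__).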

def catch_stopfwd(self, *args, **kwargs):
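        # Replaces model.forward: each calibration batch is run twice, first with quantization
        # enabled to collect the quantized layer inputs, then with quantization disabled to
        # collect the float reference inputs. Both passes are cut short by StopFwdException
        # once the hooked layers have captured their inputs.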
# Collect quant input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass

# Disable quantization
self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
# Collect float input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass

# Re-enable quantization. If activation quantization is disabled,
# we also disable bias quantization
self.disable_quant_inference.enable_param_quantization(self.model, is_training=False)
if self.use_quant_activations:
self.disable_quant_inference.enable_act_quantization(self.model, is_training=False)
else:
self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False)

if self.return_forward_output:
# If we want to return the output of the network, we need to disable all hooks
for name, gpxq_class in self.gpxq_layers.items():
gpxq_class.disable_pre_forward_hook = True
out = self.orig_forward(*args, **kwargs)
for name, gpxq_class in self.gpxq_layers.items():
gpxq_class.disable_pre_forward_hook = False
return out


class GPFQ(GPxQ):
"""
    Based on https://github.com/YixuanSeanZhou/Quantized_Neural_Nets/tree/main
"""
p = 0.25

def __init__(self, layer, name, act_order, parallel_layers=1) -> None:

if act_order:
raise ValueError("Act_order is not supported in GPFQ")

super().__init__(layer, name, act_order, parallel_layers)
self.float_input = None
self.quantized_input = None
self.index_computed = False
self.p = GPFQ.p

def update_batch(self, module, input, current_layer):
if self.disable_pre_forward_hook:
return input

# Update reference to current layer
current_layer.layer_names.add(self.name)
is_quant_disabled = module.weight_quant.disable_quant
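        # Decides whether this batch is routed to float_input (weight quantization currently
        # disabled, i.e. the float pass) or to quantized_input (quant pass); see the branch below.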

inp = self.process_input(input)
batch_size = inp.shape[0]

        # Preprocess the input before collecting it for the layer update
if isinstance(self.layer, qnn.QuantLinear):
if len(inp.shape) > 2:
inp = inp.reshape((-1, sum(inp.shape[2:])))
# For QuantLinear layer, groups will be 1
inp_processed = inp.unsqueeze(0)

if isinstance(self.layer, SUPPORTED_CONV_OP):
# Pick the correct unfoldNd class
if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
unfold_impl = unfoldNd.UnfoldTransposeNd
else:
unfold_impl = unfoldNd.UnfoldNd

unfold = unfold_impl(
self.layer.kernel_size,
dilation=self.layer.dilation,
padding=self.layer.padding,
stride=self.layer.kernel_size)
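            # unfoldNd extracts kernel-sized input patches (non-overlapping here, since the
            # stride equals the kernel size), so the convolution can be treated as a matmul
            # between the flattened weights and the unfolded input blocks.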

# Split input based on how many groups in convolution
inp_by_group = torch.chunk(inp, self.groups, 1)
inp_processed = []
# Preprocess input by group
for i, inp in enumerate(inp_by_group):

inp = unfold(inp)

batch_size, num_blocks = inp.shape[0], inp.shape[-1]
inp = torch.transpose(inp, 1, 2) # shape (B, L, C*kernel_size[0]*kernel_size[1])
inp = inp.reshape(-1, inp.size(-1)) # shape (B*L, C*kernel_size[0]*kernel_size[1])

if not self.index_computed:
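                    # Computed once on the first batch and reused afterwards: for each of the
                    # batch_size images, draw int(p * num_blocks + 1) block indices (num_blocks
                    # when p == 1), with replacement, from that image's range of unfolded blocks.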
self.index_computed = True
self.rand_indices = np.concatenate([
np.random.choice(
np.arange(num_blocks * i, num_blocks * (i + 1)),
size=int(
self.p * num_blocks + 1 if self.p != 1 else self.p * num_blocks))
                        for i in range(batch_size)])

indexes = self.rand_indices
if np.max(self.rand_indices) > inp.shape[0]:
indexes = self.rand_indices < inp.shape[0]
indexes = self.rand_indices[indexes]

inp = inp[indexes]
inp_processed.append(inp)
inp_processed = torch.stack(inp_processed)
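            # inp_processed: [groups, sampled blocks, in_channels/groups * prod(kernel_size)]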

if is_quant_disabled:
if self.float_input is None:
self.float_input = inp_processed
else:
self.float_input = torch.cat([self.float_input, inp_processed], dim=1)
else:
if self.quantized_input is None:
self.quantized_input = inp_processed
else:
self.quantized_input = torch.cat([self.quantized_input, inp_processed], dim=1)
        # If we are executing GPFQ with a group of parallel layers, we keep track of how many
        # forward passes we have executed. Once we have executed as many as the number of
        # parallel layers, we raise StopFwdException
current_layer.forward_count += 1
if current_layer.forward_count == len(self.parallel_layers):
current_layer.forward_count = 0
raise StopFwdException

def single_layer_update(self):
weight = self.layer.weight.data
dev = weight.device
dtype = weight.dtype
if isinstance(self.layer, SUPPORTED_CONV_OP):
if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
weight = weight.transpose(1, 0) # This performs a view
weight = weight.flatten(1)
weight = weight.view(self.groups, -1, weight.shape[-1]) # [Groups, OC/Groups, IC]
U = torch.zeros(
weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
self.float_input = self.float_input.to(dev)
self.quantized_input = self.quantized_input.to(dev)
permutation_list = [torch.tensor(range(weight.shape[-1]))]
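        # Greedy per-channel update: at step t, U holds the float-input contribution of the
        # weight columns up to t minus the quant-input contribution of the columns already
        # quantized. The least-squares argument U @ x_q_t / ||x_q_t||^2 becomes column t,
        # which is then quantized and its contribution subtracted from U.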
for t in range(weight.shape[-1]):
for group_index in range(self.groups):
U[group_index] += torch.matmul(
weight[group_index, :, t].unsqueeze(1),
self.float_input[group_index, :,
t].unsqueeze(0)) #[OC/Groups, 1] * [1, INSHAPE[1]]
norm = torch.linalg.norm(self.quantized_input[group_index, :, t], 2) ** 2
if norm > 0:
q_arg = U[group_index].matmul(self.quantized_input[group_index, :, t]) / norm
else:
q_arg = torch.zeros_like(U[group_index, :, 0])

weight[group_index, :, t] = q_arg
q = self.get_quant_weights(t, 0, permutation_list)
for group_index in range(self.groups):
U[group_index] -= torch.matmul(
q[group_index].unsqueeze(1),
self.quantized_input[group_index, :, t].unsqueeze(0))

del self.float_input
del self.quantized_input