Commit: Feat (examples/a2q): adding CIFAR10 example (#813)

* Feat (examples/a2q): adding CIFAR10 example
* Feat (examples/a2q): adding A2Q+ examples
* Adding descriptions to argparser
* Creating evaluation script and README doc
* Adding activation calibration
* Feat (tests): adding tests for example
* Fix (examples/a2q): updating file header
* Fix (examples/a2q): adding input assertion
* Updating paths to reflect tag name
Showing 9 changed files with 984 additions and 0 deletions.

src/brevitas_examples/imagenet_classification/a2q/README.md (28 additions, 0 deletions)
@@ -0,0 +1,28 @@
# Integer-Quantized Image Classification Models Trained on CIFAR10 with Brevitas

This directory contains scripts demonstrating how to train integer-quantized image classification models using accumulator-aware quantization (A2Q) as proposed in our ICCV 2023 paper "[A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance](https://arxiv.org/abs/2308.13504)".
Code is also provided to demonstrate A2Q+ as proposed in our arXiv paper "[A2Q+: Improving Accumulator-Aware Weight Quantization](https://arxiv.org/abs/2401.10432)", where we introduce the zero-centered weight quantizer (i.e., `AccumulatorAwareZeroCenterWeightQuant`) as well as the Euclidean projection-based weight initialization (EP-init).
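
As a rough illustration, the zero-centered A2Q+ weight quantizer might be attached to a Brevitas layer as sketched below. This is a minimal sketch, not the configuration used by these models: the import path and the `weight_accumulator_bit_width` override are assumptions, so consult the model definitions referenced by `utils.model_impl` for the exact setup.

```python
# A minimal sketch, assuming `AccumulatorAwareZeroCenterWeightQuant` is importable
# from brevitas.quant and that the accumulator constraint is set through a keyword
# override; the exact names live in the model definitions of this directory.
import torch
import brevitas.nn as qnn
from brevitas.quant import AccumulatorAwareZeroCenterWeightQuant  # assumed import path

conv = qnn.QuantConv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=3,
    weight_quant=AccumulatorAwareZeroCenterWeightQuant,  # A2Q+ zero-centered quantizer
    weight_bit_width=4,  # int4 weights, as in the experiments below
    weight_accumulator_bit_width=16)  # target int16 accumulator (assumed keyword)
out = conv(torch.randn(1, 3, 32, 32))
```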

## Experiments

All models are trained on the CIFAR10 dataset.
Input images are normalized to have zero mean and unit variance.
During training, random cropping is applied, along with random horizontal flips.
All residual connections are quantized to the specified activation bit width.
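
For reference, a minimal sketch of the data pipeline described above is given here; the normalization statistics and crop padding are assumptions, and the exact values live in `utils.get_cifar10_dataloaders`.

```python
import torchvision.transforms as T
from torchvision.datasets import CIFAR10

# Assumed CIFAR10 statistics and augmentation parameters; the values actually
# used by these examples are defined in utils.get_cifar10_dataloaders.
MEAN, STD = (0.491, 0.482, 0.447), (0.247, 0.243, 0.262)

train_transform = T.Compose([
    T.RandomCrop(32, padding=4),  # random cropping (padding value is an assumption)
    T.RandomHorizontalFlip(),     # random horizontal flips
    T.ToTensor(),
    T.Normalize(MEAN, STD)])      # normalize to zero mean and unit variance

test_transform = T.Compose([T.ToTensor(), T.Normalize(MEAN, STD)])

trainset = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
testset = CIFAR10(root="./data", train=False, download=True, transform=test_transform)
```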

| Model Name | Weight Quantization | Activation Quantization | Target Accumulator | Top-1 Accuracy (%) |
|------------|---------------------|-------------------------|--------------------|--------------------|
| [float_resnet18](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/float_resnet18-1d98d23a.pth) | float32 | float32 | float32 | 95.0 |
| | | | | |
| [quant_resnet18_w4a4_a2q_16b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_16b-d0af41f1.pth) | int4 | uint4 | int16 | 94.2 |
| [quant_resnet18_w4a4_a2q_15b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_15b-0d5bf266.pth) | int4 | uint4 | int15 | 94.2 |
| [quant_resnet18_w4a4_a2q_14b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_14b-267f237b.pth) | int4 | uint4 | int14 | 92.6 |
| [quant_resnet18_w4a4_a2q_13b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_13b-8c31a2b1.pth) | int4 | uint4 | int13 | 89.8 |
| [quant_resnet18_w4a4_a2q_12b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_12b-8a440436.pth) | int4 | uint4 | int12 | 83.9 |
| | | | | |
| [quant_resnet18_w4a4_a2q_plus_16b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_plus_16b-19973380.pth) | int4 | uint4 | int16 | 94.2 |
| [quant_resnet18_w4a4_a2q_plus_15b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_plus_15b-3c89551a.pth) | int4 | uint4 | int15 | 94.1 |
| [quant_resnet18_w4a4_a2q_plus_14b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_plus_14b-5a2d11aa.pth) | int4 | uint4 | int14 | 94.1 |
| [quant_resnet18_w4a4_a2q_plus_13b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_plus_13b-332aaf81.pth) | int4 | uint4 | int13 | 92.8 |
| [quant_resnet18_w4a4_a2q_plus_12b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_plus_12b-d69f003b.pth) | int4 | uint4 | int12 | 90.6 |
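
A short sketch of loading one of the released checkpoints above with the helpers from this directory; `pretrained=True` mirrors the evaluation script below, while the data root, worker count, and batch sizes are placeholders.

```python
import torch.nn as nn

import brevitas_examples.imagenet_classification.a2q.utils as utils

# the train loader is unused here; batch sizes and data root are placeholders
_, testloader = utils.get_cifar10_dataloaders(
    data_root="./data", batch_size_train=256, batch_size_test=512,
    num_workers=0, pin_memory=False)

# pretrained=True fetches the matching release checkpoint from the table above
model = utils.get_model_by_name("quant_resnet18_w4a4_a2q_16b", pretrained=True)

criterion = nn.CrossEntropyLoss()
top_1, top_5, loss = utils.evaluate_topk_accuracies(testloader, model, criterion)
print(f"top_1={top_1:.1%}, top_5={top_5:.1%}, loss={loss:.3f}")
```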

src/brevitas_examples/imagenet_classification/a2q/__init__.py (2 additions, 0 deletions)
@@ -0,0 +1,2 @@
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

src/brevitas_examples/imagenet_classification/a2q/a2q_evaluate_models.py (107 additions, 0 deletions)
@@ -0,0 +1,107 @@
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import argparse
from hashlib import sha256
import os
import random

import numpy as np
import torch
import torch.nn as nn

import brevitas.config as config
from brevitas.export import export_qonnx
import brevitas_examples.imagenet_classification.a2q.utils as utils

parser = argparse.ArgumentParser()
parser.add_argument(
    "--data-root", type=str, required=True, help="Directory where the dataset is stored.")
parser.add_argument(
    "--model-name",
    type=str,
    default="quant_resnet18_w4a4_a2q_32b",
    help="Name of the model to evaluate. Default: 'quant_resnet18_w4a4_a2q_32b'",
    choices=utils.model_impl.keys())
parser.add_argument(
    "--save-path",
    type=str,
    default="outputs/",
    help="Directory where to save checkpoints. Default: 'outputs/'")
parser.add_argument(
    "--load-from-path",
    type=str,
    default=None,
    help="Optional local path to load a torch checkpoint from. Default: None")
parser.add_argument(
    "--num-workers",
    type=int,
    default=0,
    help="Number of workers for the dataloader to use. Default: 0")
parser.add_argument(
    "--pin-memory",
    action="store_true",
    default=False,
    help="If true, pin memory for the dataloader.")
parser.add_argument(
    "--batch-size", type=int, default=512, help="Batch size for the dataloader. Default: 512")
parser.add_argument(
    "--save-torch-model",
    action="store_true",
    default=False,
    help="If true, save the torch model to the specified save path.")
parser.add_argument(
    "--export-to-qonnx", action="store_true", default=False, help="If true, export model to QONNX.")

SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# create a random input for graph tracing
random_inp = torch.randn(1, 3, 32, 32)

if __name__ == "__main__":

    args = parser.parse_args()

    # the QONNX export path requires the non-JIT quantizer implementations
    config.JIT_ENABLED = not args.export_to_qonnx

    # Initialize dataloaders
    print(f"Loading CIFAR10 dataset from {args.data_root}...")
    trainloader, testloader = utils.get_cifar10_dataloaders(
        data_root=args.data_root,
        batch_size_train=args.batch_size,  # the training batch size does not matter here
        batch_size_test=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory)

    # if load-from-path is not specified, then use the pre-trained checkpoint
    model = utils.get_model_by_name(args.model_name, pretrained=args.load_from_path is None)
    if args.load_from_path is not None:
        # note that if you used bias correction, you may need to prepare the model for the
        # new biases that were introduced. See `utils.get_model_by_name` for more details.
        state_dict = torch.load(args.load_from_path, map_location="cpu")
        model.load_state_dict(state_dict)
    criterion = nn.CrossEntropyLoss()

    top_1, top_5, loss = utils.evaluate_topk_accuracies(testloader, model, criterion)
    print(f"Final top_1={top_1:.1%}, top_5={top_5:.1%}, loss={loss:.3f}")

    # save checkpoint, tagging it with the first 8 hex digits of its SHA-256 hash
    os.makedirs(args.save_path, exist_ok=True)
    model_tag = "latest"  # fallback tag so the QONNX export below works without --save-torch-model
    if args.save_torch_model:
        ckpt_path = f"{args.save_path}/{args.model_name}.pth"
        torch.save(model.state_dict(), ckpt_path)
        with open(ckpt_path, "rb") as _file:
            model_bytes = _file.read()  # renamed from `bytes` to avoid shadowing the builtin
            model_tag = sha256(model_bytes).hexdigest()[:8]
        new_ckpt_path = f"{args.save_path}/{args.model_name}-{model_tag}.pth"
        os.rename(ckpt_path, new_ckpt_path)
        print(f"Saved model checkpoint to: {new_ckpt_path}")

    if args.export_to_qonnx:
        export_qonnx(
            model.cpu(),
            input_t=random_inp.cpu(),
            export_path=f"{args.save_path}/{args.model_name}-{model_tag}.onnx")
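
A usage sketch (the data root is a placeholder): `python src/brevitas_examples/imagenet_classification/a2q/a2q_evaluate_models.py --data-root ./data --model-name quant_resnet18_w4a4_a2q_16b --save-torch-model --export-to-qonnx` evaluates the released int16-accumulator checkpoint on the CIFAR10 test set, saves a hash-tagged `.pth` under `outputs/`, and exports a QONNX graph alongside it.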

src/brevitas_examples/imagenet_classification/a2q/a2q_train_models.py (204 additions, 0 deletions)
@@ -0,0 +1,204 @@
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import argparse
import copy
from hashlib import sha256
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lrs

import brevitas.config as config
from brevitas.export import export_qonnx
import brevitas_examples.imagenet_classification.a2q.utils as utils

parser = argparse.ArgumentParser()
parser.add_argument(
    "--data-root", type=str, required=True, help="Directory where the dataset is stored.")
parser.add_argument(
    "--model-name",
    type=str,
    default="quant_resnet18_w4a4_a2q_32b",
    help="Name of the model to train. Default: 'quant_resnet18_w4a4_a2q_32b'",
    choices=utils.model_impl.keys())
parser.add_argument(
    "--save-path",
    type=str,
    default="outputs/",
    help="Directory where to save checkpoints. Default: 'outputs/'")
parser.add_argument(
    "--num-workers",
    type=int,
    default=0,
    help="Number of workers for the dataloader to use. Default: 0")
parser.add_argument(
    "--pin-memory",
    action="store_true",
    default=False,
    help="If true, pin memory for the dataloader.")
parser.add_argument(
    "--batch-size-train",
    type=int,
    default=256,
    help="Batch size for the training dataloader. Default: 256")
parser.add_argument(
    "--batch-size-test",
    type=int,
    default=512,
    help="Batch size for the testing dataloader. Default: 512")
parser.add_argument(
    "--batch-size-calibration",
    type=int,
    default=256,
    help="Batch size for the calibration dataloader. Default: 256")
parser.add_argument(
    "--calibration-samples",
    type=int,
    default=1000,
    help="Number of samples to use for calibration. Default: 1000")
parser.add_argument(
    "--weight-decay",
    type=float,
    default=1e-5,
    help="Weight decay for the SGD optimizer. Default: 0.00001")
parser.add_argument(
    "--lr-init", type=float, default=1e-3, help="Initial learning rate. Default: 0.001")
parser.add_argument(
    "--lr-step-size",
    type=int,
    default=30,
    help="Step size for the learning rate scheduler. Default: 30")
parser.add_argument(
    "--lr-gamma",
    type=float,
    default=0.1,
    help="Multiplicative decay factor (gamma) for the learning rate scheduler. Default: 0.1")
parser.add_argument(
    "--total-epochs",
    type=int,
    default=90,
    help="Total number of epochs to train the model for. Default: 90")
parser.add_argument(
    "--from-float-checkpoint",
    action="store_true",
    default=False,
    help="If true, initialize from a pre-trained floating-point checkpoint.")
parser.add_argument(
    "--save-torch-model",
    action="store_true",
    default=False,
    help="If true, save the torch model to the specified save path.")
parser.add_argument(
    "--apply-act-calibration",
    action="store_true",
    default=False,
    help="If true, apply activation calibration to the quantized model.")
parser.add_argument(
    "--apply-bias-correction",
    action="store_true",
    default=False,
    help="If true, apply bias correction to the quantized model.")
parser.add_argument(
    "--apply-ep-init",
    action="store_true",
    default=False,
    help="If true, apply EP-init to the quantized model.")
parser.add_argument(
    "--export-to-qonnx", action="store_true", default=False, help="If true, export model to QONNX.")

# ignore missing keys when loading the pre-trained checkpoint
config.IGNORE_MISSING_KEYS = True

SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# create a random input for graph tracing
random_inp = torch.randn(1, 3, 32, 32)

if __name__ == "__main__":

    args = parser.parse_args()

    # the QONNX export path requires the non-JIT quantizer implementations
    config.JIT_ENABLED = not args.export_to_qonnx

    # Initialize dataloaders
    print(f"Loading CIFAR10 dataset from {args.data_root}...")
    trainloader, testloader = utils.get_cifar10_dataloaders(
        data_root=args.data_root,
        batch_size_train=args.batch_size_train,
        batch_size_test=args.batch_size_test,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory)
    calibloader = utils.create_calibration_dataloader(
        dataset=trainloader.dataset,
        batch_size=args.batch_size_calibration,
        num_workers=args.num_workers,
        subset_size=args.calibration_samples)

    model = utils.get_model_by_name(
        args.model_name, init_from_float_checkpoint=args.from_float_checkpoint)
    criterion = nn.CrossEntropyLoss()
    # filter_params groups parameters by whether they should receive weight decay
    # (see utils.py for the exact grouping)
    optimizer = optim.SGD(
        utils.filter_params(model.named_parameters(), args.weight_decay),
        lr=args.lr_init,
        weight_decay=args.weight_decay)
    scheduler = lrs.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    # Apply Euclidean projection-based weight initialization (EP-init)
    if args.apply_ep_init:
        print("Applying EP-init:")
        model = utils.apply_ep_init(model, random_inp)

    # Calibrate the quant model on the calibration dataset
    if args.apply_act_calibration:
        print("Applying activation calibration:")
        utils.apply_act_calibrate(calibloader, model)

    if args.apply_bias_correction:
        print("Applying bias correction:")
        utils.apply_bias_correction(calibloader, model)

    best_top_1, best_weights = 0., copy.deepcopy(model.state_dict())
    for epoch in range(args.total_epochs):

        train_loss = utils.train_for_epoch(trainloader, model, criterion, optimizer)
        test_top_1, test_top_5, test_loss = utils.evaluate_topk_accuracies(
            testloader, model, criterion)
        scheduler.step()

        print(
            f"[Epoch {epoch:03d}]",
            f"train_loss={train_loss:.3f},",
            f"test_loss={test_loss:.3f},",
            f"test_top_1={test_top_1:.1%},",
            f"test_top_5={test_top_5:.1%}",
            sep=" ")

        # keep the weights with the best top-1 test accuracy
        if test_top_1 >= best_top_1:
            best_weights = copy.deepcopy(model.state_dict())
            best_top_1 = test_top_1

    model.load_state_dict(best_weights)
    top_1, top_5, loss = utils.evaluate_topk_accuracies(testloader, model, criterion)
    print(f"Final: top_1={top_1:.1%}, top_5={top_5:.1%}, loss={loss:.3f}")

    # save checkpoint, tagging it with the first 8 hex digits of its SHA-256 hash
    os.makedirs(args.save_path, exist_ok=True)
    model_tag = "latest"  # fallback tag so the QONNX export below works without --save-torch-model
    if args.save_torch_model:
        ckpt_path = f"{args.save_path}/{args.model_name}.pth"
        torch.save(best_weights, ckpt_path)
        with open(ckpt_path, "rb") as _file:
            model_bytes = _file.read()  # renamed from `bytes` to avoid shadowing the builtin
            model_tag = sha256(model_bytes).hexdigest()[:8]
        new_ckpt_path = f"{args.save_path}/{args.model_name}-{model_tag}.pth"
        os.rename(ckpt_path, new_ckpt_path)
        print(f"Saved model checkpoint to {new_ckpt_path}")

    if args.export_to_qonnx:
        export_qonnx(
            model.cpu(),
            input_t=random_inp.cpu(),
            export_path=f"{args.save_path}/{args.model_name}-{model_tag}.onnx")
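
A usage sketch (the data root is a placeholder and the model name is taken from the table above): `python src/brevitas_examples/imagenet_classification/a2q/a2q_train_models.py --data-root ./data --model-name quant_resnet18_w4a4_a2q_plus_16b --from-float-checkpoint --apply-ep-init --apply-act-calibration --apply-bias-correction --save-torch-model` initializes from the floating-point checkpoint, applies EP-init, activation calibration, and bias correction before the training loop, then trains for 90 epochs and saves a hash-tagged checkpoint of the best weights.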