From 8bf6dd6a12acfe731370679c66d7546efaa5414a Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 31 Aug 2021 15:44:20 -0700 Subject: [PATCH] feat: Add VGG QAT sample notebook which demonstrates end-end workflow for QAT models Signed-off-by: Dheeraj Peri --- notebooks/vgg_qat.ipynb | 1074 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 1074 insertions(+) create mode 100644 notebooks/vgg_qat.ipynb diff --git a/notebooks/vgg_qat.ipynb b/notebooks/vgg_qat.ipynb new file mode 100644 index 0000000000..3e4e739c3d --- /dev/null +++ b/notebooks/vgg_qat.ipynb @@ -0,0 +1,1074 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "86b49f2a", + "metadata": {}, + "source": [ + "# Deploying Quantization Aware Trained models in INT8 using TRTorch" + ] + },
+ { + "cell_type": "markdown", + "id": "feb10417", + "metadata": {}, + "source": [ + "## Overview\n", + "\n", + "Quantization Aware Training (QAT) simulates quantization during training by quantizing weights and activation layers. This helps reduce the loss in accuracy when the network trained in FP32 is converted to INT8 for faster inference. QAT introduces additional nodes in the graph which are used to learn the dynamic ranges of weights and activation layers. In this notebook, we illustrate the following steps from training to inference of a QAT model in TRTorch.\n", + "\n", + "1. [Requirements](#1)\n", + "2. [VGG16 Overview](#2)\n", + "3. [Training a baseline VGG16 model](#3)\n", + "4. [Apply Quantization](#4)\n", + "5. [Model calibration](#5)\n", + "6. [Quantization Aware training](#6)\n", + "7. [Export to Torchscript](#7)\n", + "8. [Inference using TRTorch](#8)\n", + "9. [References](#9)" + ] + },
+ { + "cell_type": "markdown", + "id": "2b3f1107", + "metadata": {}, + "source": [ + "\n", + "## 1. Requirements\n", + "Please install the required dependencies and import the libraries listed below." + ] + },
+ { + "cell_type": "code", + "execution_count": 2, + "id": "59d11964", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Logging before flag parsing goes to stderr.\n", + "E0831 15:09:13.151450 140586428176192 amp_wrapper.py:31] AMP is not avaialble.\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "import torch.utils.data as data\n", + "import torchvision.transforms as transforms\n", + "import torchvision.datasets as datasets\n", + "import trtorch\n", + "\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "\n", + "from pytorch_quantization import nn as quant_nn\n", + "from pytorch_quantization import quant_modules\n", + "from pytorch_quantization.tensor_quant import QuantDescriptor\n", + "from pytorch_quantization import calib\n", + "from tqdm import tqdm\n", + "\n", + "import os\n", + "import sys\n", + "sys.path.insert(0, \"../examples/int8/training/vgg16\")\n", + "from vgg16 import vgg16\n" + ] + },
+ { + "cell_type": "markdown", + "id": "88319c40", + "metadata": {}, + "source": [ + "\n", + "## 2. VGG16 Overview\n", + "### Very Deep Convolutional Networks for Large-Scale Image Recognition\n", + "VGG is one of the earliest families of image classification networks. It was among the first to use small (3x3) convolution filters and achieved significant improvements on the ImageNet recognition challenge. The network architecture is shown below.\n", + "\n", + " " + ] + },
+ { + "cell_type": "markdown", + "id": "57db62be", + "metadata": {}, + "source": [ + "\n",
+ "## 3. Training a baseline VGG16 model\n", + "We train VGG16 on the CIFAR10 dataset. Define the training and testing datasets and dataloaders. This will download the CIFAR10 data into your `data` directory. Data preprocessing is performed using `torchvision` transforms. " + ] + },
+ { + "cell_type": "code", + "execution_count": 3, + "id": "d799dc37", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Files already downloaded and verified\n", + "Files already downloaded and verified\n" + ] + } + ], + "source": [ + "classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n", + "\n", + "# ========== Define Training dataset and dataloaders =============#\n", + "training_dataset = datasets.CIFAR10(root='./data',\n", + " train=True,\n", + " download=True,\n", + " transform=transforms.Compose([\n", + " transforms.RandomCrop(32, padding=4),\n", + " transforms.RandomHorizontalFlip(),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n", + " ]))\n", + "\n", + "training_dataloader = torch.utils.data.DataLoader(training_dataset,\n", + " batch_size=32,\n", + " shuffle=True,\n", + " num_workers=2)\n", + "\n", + "# ========== Define Testing dataset and dataloaders =============#\n", + "testing_dataset = datasets.CIFAR10(root='./data',\n", + " train=False,\n", + " download=True,\n", + " transform=transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n", + " ]))\n", + "\n", + "testing_dataloader = torch.utils.data.DataLoader(testing_dataset,\n", + " batch_size=16,\n", + " shuffle=False,\n", + " num_workers=2)\n" + ] + },
+ { + "cell_type": "code", + "execution_count": 4, + "id": "c7127092", + "metadata": {}, + "outputs": [], + "source": [ + "def train(model, dataloader, crit, opt, epoch):\n", + "# global writer\n", + " model.train()\n", + " running_loss = 0.0\n", + " for batch, (data, labels) in enumerate(dataloader):\n", + " data, labels = data.cuda(), labels.cuda(non_blocking=True)\n", + " opt.zero_grad()\n", + " out = model(data)\n", + " loss = crit(out, labels)\n", + " loss.backward()\n", + " opt.step()\n", + "\n", + " running_loss += loss.item()\n", + " if batch % 500 == 499:\n", + " print(\"Batch: [%5d | %5d] loss: %.3f\" % (batch + 1, len(dataloader), running_loss / 100))\n", + " running_loss = 0.0\n", + " \n", + "def test(model, dataloader, crit, epoch):\n", + " global writer\n", + " global classes\n", + " total = 0\n", + " correct = 0\n", + " loss = 0.0\n", + " class_probs = []\n", + " class_preds = []\n", + " model.eval()\n", + " with torch.no_grad():\n", + " for data, labels in dataloader:\n", + " data, labels = data.cuda(), labels.cuda(non_blocking=True)\n", + " out = model(data)\n", + " loss += crit(out, labels)\n", + " preds = torch.max(out, 1)[1]\n", + " class_probs.append([F.softmax(i, dim=0) for i in out])\n", + " class_preds.append(preds)\n", + " total += labels.size(0)\n", + " correct += (preds == labels).sum().item()\n", + "\n", + " test_probs = torch.cat([torch.stack(batch) for batch in class_probs])\n", + " test_preds = torch.cat(class_preds)\n", + "\n", + " return loss / total, correct / total\n", + "\n", + "def save_checkpoint(state, ckpt_path=\"checkpoint.pth\"):\n", + " torch.save(state, ckpt_path)\n", + " print(\"Checkpoint saved\")" + ] + },
+ { + "cell_type": "markdown", + "id": "e562f7b6", + "metadata": {}, + "source": [
+ "*Define the VGG model that we are going to perform QAT on.*" + ] + },
+ { + "cell_type": "code", + "execution_count": 5, + "id": "5233b3ee", + "metadata": {}, + "outputs": [], + "source": [ + "# CIFAR 10 has 10 classes\n", + "model = vgg16(num_classes=len(classes), init_weights=False)\n", + "model = model.cuda()" + ] + },
+ { + "cell_type": "code", + "execution_count": 6, + "id": "2598ea9c", + "metadata": {}, + "outputs": [], + "source": [ + "# Declare the learning rate\n", + "lr = 0.1\n", + "state = {}\n", + "state[\"lr\"] = lr\n", + "\n", + "# Use cross entropy loss for classification and SGD optimizer\n", + "crit = nn.CrossEntropyLoss()\n", + "opt = optim.SGD(model.parameters(), lr=state[\"lr\"], momentum=0.9, weight_decay=1e-4)\n", + "\n", + "\n", + "# Adjust learning rate based on epoch number\n", + "def adjust_lr(optimizer, epoch):\n", + " global state\n", + " new_lr = lr * (0.5**(epoch // 12)) if state[\"lr\"] > 1e-7 else state[\"lr\"]\n", + " if new_lr != state[\"lr\"]:\n", + " state[\"lr\"] = new_lr\n", + " print(\"Updating learning rate: {}\".format(state[\"lr\"]))\n", + " for param_group in optimizer.param_groups:\n", + " param_group[\"lr\"] = state[\"lr\"]" + ] + },
+ { + "cell_type": "code", + "execution_count": 7, + "id": "7ca88929", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: [ 1 / 25] LR: 0.100000\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/dperi/Downloads/py3/lib/python3.6/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)\n", + " return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: [ 500 | 1563] loss: 12.466\n", + "Batch: [ 1000 | 1563] loss: 10.726\n", + "Batch: [ 1500 | 1563] loss: 10.289\n", + "Test Loss: 0.12190 Test Acc: 19.86%\n", + "Epoch: [ 2 / 25] LR: 0.100000\n", + "Batch: [ 500 | 1563] loss: 10.107\n", + "Batch: [ 1000 | 1563] loss: 9.986\n", + "Batch: [ 1500 | 1563] loss: 9.994\n", + "Test Loss: 0.12230 Test Acc: 21.54%\n", + "Epoch: [ 3 / 25] LR: 0.100000\n", + "Batch: [ 500 | 1563] loss: 9.826\n", + "Batch: [ 1000 | 1563] loss: 9.904\n", + "Batch: [ 1500 | 1563] loss: 9.771\n", + "Test Loss: 0.11709 Test Acc: 22.71%\n", + "Epoch: [ 4 / 25] LR: 0.100000\n", + "Batch: [ 500 | 1563] loss: 9.760\n", + "Batch: [ 1000 | 1563] loss: 9.629\n", + "Batch: [ 1500 | 1563] loss: 9.642\n", + "Test Loss: 0.11945 Test Acc: 23.89%\n", + "Epoch: [ 5 / 25] LR: 0.100000\n", + "Batch: [ 500 | 1563] loss: 9.590\n", + "Batch: [ 1000 | 1563] loss: 9.489\n", + "Batch: [ 1500 | 1563] loss: 9.468\n", + "Test Loss: 0.11180 Test Acc: 30.01%\n", + "Epoch: [ 6 / 25] LR: 0.100000\n", + "Batch: [ 500 | 1563] loss: 9.281\n", + "Batch: [ 1000 | 1563] loss: 9.057\n", + "Batch: [ 1500 | 1563] loss: 8.957\n", + "Test Loss: 0.11106 Test Acc: 28.03%\n", + "Epoch: [ 7 / 25] LR: 0.100000\n", + "Batch: [ 500 | 1563] loss: 8.799\n", + "Batch: [ 1000 | 1563] loss: 8.808\n", + "Batch: [ 1500 | 1563] loss: 8.647\n", + "Test Loss: 0.10456 Test Acc: 32.25%\n", + "Epoch: [ 8 / 25] LR: 0.100000\n", + "Batch: [ 500 | 1563] loss: 8.672\n", + "Batch: [ 1000 | 1563] loss: 8.478\n", + "Batch: [ 1500 | 1563] loss: 8.522\n", + "Test Loss: 0.10404 Test Acc: 32.40%\n", + "Epoch: [ 9 / 25] LR: 0.100000\n", +
"Batch: [ 500 | 1563] loss: 8.422\n", + "Batch: [ 1000 | 1563] loss: 8.290\n", + "Batch: [ 1500 | 1563] loss: 8.474\n", + "Test Loss: 0.10282 Test Acc: 41.11%\n", + "Epoch: [ 10 / 25] LR: 0.100000\n", + "Batch: [ 500 | 1563] loss: 8.131\n", + "Batch: [ 1000 | 1563] loss: 8.005\n", + "Batch: [ 1500 | 1563] loss: 8.074\n", + "Test Loss: 0.09473 Test Acc: 38.91%\n", + "Epoch: [ 11 / 25] LR: 0.100000\n", + "Batch: [ 500 | 1563] loss: 8.132\n", + "Batch: [ 1000 | 1563] loss: 8.047\n", + "Batch: [ 1500 | 1563] loss: 7.941\n", + "Test Loss: 0.09928 Test Acc: 41.69%\n", + "Epoch: [ 12 / 25] LR: 0.100000\n", + "Batch: [ 500 | 1563] loss: 7.911\n", + "Batch: [ 1000 | 1563] loss: 7.974\n", + "Batch: [ 1500 | 1563] loss: 7.871\n", + "Test Loss: 0.10598 Test Acc: 38.90%\n", + "Updating learning rate: 0.05\n", + "Epoch: [ 13 / 25] LR: 0.050000\n", + "Batch: [ 500 | 1563] loss: 6.981\n", + "Batch: [ 1000 | 1563] loss: 6.543\n", + "Batch: [ 1500 | 1563] loss: 6.377\n", + "Test Loss: 0.07362 Test Acc: 53.72%\n", + "Epoch: [ 14 / 25] LR: 0.050000\n", + "Batch: [ 500 | 1563] loss: 6.208\n", + "Batch: [ 1000 | 1563] loss: 6.113\n", + "Batch: [ 1500 | 1563] loss: 6.016\n", + "Test Loss: 0.07922 Test Acc: 55.78%\n", + "Epoch: [ 15 / 25] LR: 0.050000\n", + "Batch: [ 500 | 1563] loss: 5.945\n", + "Batch: [ 1000 | 1563] loss: 5.726\n", + "Batch: [ 1500 | 1563] loss: 5.568\n", + "Test Loss: 0.05914 Test Acc: 65.33%\n", + "Epoch: [ 16 / 25] LR: 0.050000\n", + "Batch: [ 500 | 1563] loss: 5.412\n", + "Batch: [ 1000 | 1563] loss: 5.356\n", + "Batch: [ 1500 | 1563] loss: 5.143\n", + "Test Loss: 0.05833 Test Acc: 68.91%\n", + "Epoch: [ 17 / 25] LR: 0.050000\n", + "Batch: [ 500 | 1563] loss: 5.096\n", + "Batch: [ 1000 | 1563] loss: 5.064\n", + "Batch: [ 1500 | 1563] loss: 4.962\n", + "Test Loss: 0.05291 Test Acc: 71.72%\n", + "Epoch: [ 18 / 25] LR: 0.050000\n", + "Batch: [ 500 | 1563] loss: 4.958\n", + "Batch: [ 1000 | 1563] loss: 4.887\n", + "Batch: [ 1500 | 1563] loss: 4.711\n", + "Test Loss: 0.05003 Test Acc: 73.61%\n", + "Epoch: [ 19 / 25] LR: 0.050000\n", + "Batch: [ 500 | 1563] loss: 4.651\n", + "Batch: [ 1000 | 1563] loss: 4.567\n", + "Batch: [ 1500 | 1563] loss: 4.603\n", + "Test Loss: 0.05046 Test Acc: 73.80%\n", + "Epoch: [ 20 / 25] LR: 0.050000\n", + "Batch: [ 500 | 1563] loss: 4.467\n", + "Batch: [ 1000 | 1563] loss: 4.399\n", + "Batch: [ 1500 | 1563] loss: 4.310\n", + "Test Loss: 0.05038 Test Acc: 74.45%\n", + "Epoch: [ 21 / 25] LR: 0.050000\n", + "Batch: [ 500 | 1563] loss: 4.226\n", + "Batch: [ 1000 | 1563] loss: 4.196\n", + "Batch: [ 1500 | 1563] loss: 4.169\n", + "Test Loss: 0.05287 Test Acc: 71.18%\n", + "Epoch: [ 22 / 25] LR: 0.050000\n", + "Batch: [ 500 | 1563] loss: 4.120\n", + "Batch: [ 1000 | 1563] loss: 4.035\n", + "Batch: [ 1500 | 1563] loss: 4.018\n", + "Test Loss: 0.06157 Test Acc: 70.29%\n", + "Epoch: [ 23 / 25] LR: 0.050000\n", + "Batch: [ 500 | 1563] loss: 3.915\n", + "Batch: [ 1000 | 1563] loss: 3.968\n", + "Batch: [ 1500 | 1563] loss: 3.989\n", + "Test Loss: 0.04128 Test Acc: 79.01%\n", + "Epoch: [ 24 / 25] LR: 0.050000\n", + "Batch: [ 500 | 1563] loss: 3.871\n", + "Batch: [ 1000 | 1563] loss: 3.800\n", + "Batch: [ 1500 | 1563] loss: 3.871\n", + "Test Loss: 0.04785 Test Acc: 75.77%\n", + "Updating learning rate: 0.025\n", + "Epoch: [ 25 / 25] LR: 0.025000\n", + "Batch: [ 500 | 1563] loss: 3.141\n", + "Batch: [ 1000 | 1563] loss: 2.979\n", + "Batch: [ 1500 | 1563] loss: 2.874\n", + "Test Loss: 0.03345 Test Acc: 83.15%\n", + "Checkpoint saved\n" + ] + } + ], + "source": [ + "# Train the model 
for 25 epochs to get ~80% accuracy.\n", + "num_epochs=25\n", + "for epoch in range(num_epochs):\n", + " adjust_lr(opt, epoch)\n", + " print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, num_epochs, state[\"lr\"]))\n", + "\n", + " train(model, training_dataloader, crit, opt, epoch)\n", + " test_loss, test_acc = test(model, testing_dataloader, crit, epoch)\n", + "\n", + " print(\"Test Loss: {:.5f} Test Acc: {:.2f}%\".format(test_loss, 100 * test_acc))\n", + " \n", + "save_checkpoint({'epoch': epoch + 1,\n", + " 'model_state_dict': model.state_dict(),\n", + " 'acc': test_acc,\n", + " 'opt_state_dict': opt.state_dict(),\n", + " 'state': state},\n", + " ckpt_path=\"vgg16_base_ckpt\")" + ] + },
+ { + "cell_type": "markdown", + "id": "2b5b7a43", + "metadata": {}, + "source": [ + "\n", + "## 4. Apply Quantization" + ] + },
+ { + "cell_type": "markdown", + "id": "c4057574", + "metadata": {}, + "source": [ + "`quant_modules.initialize()` ensures that the quantized versions of modules will be called instead of the original modules. For example, when you define a model with convolution, linear and pooling layers, `QuantConv2d`, `QuantLinear` and `QuantPooling` will be called instead. `QuantConv2d` basically wraps quantizer nodes around the inputs and weights of a regular `Conv2d`. Please refer to the quantized modules in the `pytorch-quantization` toolkit for more information. A `QuantConv2d` is represented in the `pytorch-quantization` toolkit as follows.\n", + "\n", + "```\n", + "def forward(self, input):\n", + " # the actual quantization happens in the next level of the class hierarchy\n", + " quant_input, quant_weight = self._quant(input)\n", + "\n", + " if self.padding_mode == 'circular':\n", + " expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,\n", + " (self.padding[0] + 1) // 2, self.padding[0] // 2)\n", + " output = F.conv2d(F.pad(quant_input, expanded_padding, mode='circular'),\n", + " quant_weight, self.bias, self.stride,\n", + " _pair(0), self.dilation, self.groups)\n", + " else:\n", + " output = F.conv2d(quant_input, quant_weight, self.bias, self.stride, self.padding, self.dilation,\n", + " self.groups)\n", + "\n", + " return output\n", + "```" + ] + },
+ { + "cell_type": "code", + "execution_count": 8, + "id": "5556aadf", + "metadata": {}, + "outputs": [], + "source": [ + "quant_modules.initialize()" + ] + },
+ { + "cell_type": "code", + "execution_count": 9, + "id": "c8854f15", + "metadata": {}, + "outputs": [], + "source": [ + "# All the regular conv and FC layers will be converted to their quantized counterparts due to quant_modules.initialize()\n", + "qat_model = vgg16(num_classes=len(classes), init_weights=False)\n", + "qat_model = qat_model.cuda()" + ] + },
+ { + "cell_type": "code", + "execution_count": 10, + "id": "2a1a3f36", + "metadata": {}, + "outputs": [], + "source": [ + "# vgg16_base_ckpt is the checkpoint generated from Step 3 : Training a baseline VGG16 model.\n", + "ckpt = torch.load(\"./vgg16_base_ckpt\")\n", + "modified_state_dict={}\n", + "for key, val in ckpt[\"model_state_dict\"].items():\n", + " # Remove 'module.' from the key names\n", + " if key.startswith('module'):\n", + " modified_state_dict[key[7:]] = val\n", + " else:\n", + " modified_state_dict[key] = val\n", + "\n", + "# Load the pre-trained checkpoint\n", + "qat_model.load_state_dict(modified_state_dict)\n", + "opt.load_state_dict(ckpt[\"opt_state_dict\"])" + ] + },
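+ { + "cell_type": "markdown", + "id": "a1b2c3d4", + "metadata": {}, + "source": [ + "As a quick sanity check (an illustrative cell added for this walkthrough; output not shown), you can print a layer to verify that `quant_modules.initialize()` swapped in the quantized modules:" + ] + },
+ { + "cell_type": "code", + "execution_count": null, + "id": "e5f6a7b8", + "metadata": {}, + "outputs": [], + "source": [ + "# The first conv layer should now be a QuantConv2d carrying\n", + "# _input_quantizer and _weight_quantizer TensorQuantizer children\n", + "print(qat_model.features[0])" + ] + },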
+ { + "cell_type": "markdown", + "id": "05e89b6a", + "metadata": {}, + "source": [ + "\n", + "## 5. Model Calibration" + ] + },
+ { + "cell_type": "markdown", + "id": "14aa4656", + "metadata": {}, + "source": [ + "The quantizer nodes introduced in the model around the desired layers capture the dynamic range (min_value, max_value) that is observed by the layer. Calibration is the process of computing the dynamic range of these layers by passing calibration data, which is usually a subset of training or validation data. There are different ways of calibration: `max`, or histogram-based methods (`percentile`, `mse`, `entropy`). We use the `max` calibration technique as it is simple and effective. " + ] + },
+ { + "cell_type": "code", + "execution_count": 11, + "id": "720b3615", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_amax(model, **kwargs):\n", + " # Load calib result\n", + " for name, module in model.named_modules():\n", + " if isinstance(module, quant_nn.TensorQuantizer):\n", + " if module._calibrator is not None:\n", + " if isinstance(module._calibrator, calib.MaxCalibrator):\n", + " module.load_calib_amax()\n", + " else:\n", + " module.load_calib_amax(**kwargs)\n", + " print(F\"{name:40}: {module}\")\n", + " model.cuda()\n", + "\n", + "def collect_stats(model, data_loader, num_batches):\n", + " \"\"\"Feed data to the network and collect statistics\"\"\"\n", + " # Enable calibrators\n", + " for name, module in model.named_modules():\n", + " if isinstance(module, quant_nn.TensorQuantizer):\n", + " if module._calibrator is not None:\n", + " module.disable_quant()\n", + " module.enable_calib()\n", + " else:\n", + " module.disable()\n", + "\n", + " # Feed data to the network for collecting stats\n", + " for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches):\n", + " model(image.cuda())\n", + " if i >= num_batches:\n", + " break\n", + "\n", + " # Disable calibrators\n", + " for name, module in model.named_modules():\n", + " if isinstance(module, quant_nn.TensorQuantizer):\n", + " if module._calibrator is not None:\n", + " module.enable_quant()\n", + " module.disable_calib()\n", + " else:\n", + " module.enable()\n", + "\n", + "def calibrate_model(model, model_name, data_loader, num_calib_batch, calibrator, hist_percentile, out_dir):\n", + " \"\"\"\n", + " Feed data to the network and calibrate.\n", + " Arguments:\n", + " model: classification model\n", + " model_name: name to use when creating state files\n", + " data_loader: calibration data set\n", + " num_calib_batch: number of calibration batches to run\n", + " calibrator: type of calibration to use (max/histogram)\n", + " hist_percentile: percentiles to be used for histogram calibration\n", + " out_dir: dir to save state files in\n", + " \"\"\"\n", + "\n", + " if num_calib_batch > 0:\n", + " print(\"Calibrating model\")\n", + " with torch.no_grad():\n", + " collect_stats(model, data_loader, num_calib_batch)\n", + "\n", + " if not calibrator == \"histogram\":\n", + " compute_amax(model, method=\"max\")\n", + " calib_output = os.path.join(\n", + " out_dir,\n", + " F\"{model_name}-max-{num_calib_batch*data_loader.batch_size}.pth\")\n", + " torch.save(model.state_dict(), calib_output)\n", + " else:\n", + " for percentile in hist_percentile:\n", + " print(F\"{percentile} percentile calibration\")\n", + " compute_amax(model, method=\"percentile\", percentile=percentile)\n", + " calib_output = os.path.join(\n", + " out_dir,\n", + " F\"{model_name}-percentile-{percentile}-{num_calib_batch*data_loader.batch_size}.pth\")\n", + " torch.save(model.state_dict(), calib_output)\n", + "\n", + " for method in [\"mse\", \"entropy\"]:\n", + " print(F\"{method} calibration\")\n",
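+ " # mse and entropy recompute amax from the histogram collected above,\n", + " # so no additional pass over the calibration data is needed\n",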
+ " compute_amax(model, method=method)\n", + " calib_output = os.path.join(\n", + " out_dir,\n", + " F\"{model_name}-{method}-{num_calib_batch*data_loader.batch_size}.pth\")\n", + " torch.save(model.state_dict(), calib_output)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "d8df756c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Calibrating model\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 32/32 [00:00<00:00, 85.05it/s]\n", + "W0831 15:32:46.956144 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.957227 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.958076 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.958884 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.959697 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.960512 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.961301 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.962079 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.962872 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.963665 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.964508 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.965338 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.966276 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.967190 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.967864 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.968530 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.969168 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.969751 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.970463 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.971141 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.971790 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.972402 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.973017 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.973696 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.974347 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.974952 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.975592 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.976269 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.976892 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.977430 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.977965 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.978480 140586428176192 tensor_quantizer.py:173] Disable MaxCalibrator\n", + "W0831 15:32:46.979063 140586428176192 
tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n", + "W0831 15:32:46.979588 140586428176192 tensor_quantizer.py:239] Call .cuda() if running on GPU after loading calibrated amax.\n", + "W0831 15:32:46.980288 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([64, 1, 1, 1]).\n", + "W0831 15:32:46.987690 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n", + "W0831 15:32:57.002152 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([64, 1, 1, 1]).\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "features.0._input_quantizer : TensorQuantizer(8bit narrow fake per-tensor amax=2.7537 calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.0._weight_quantizer : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0287, 4.4272](64) calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.3._input_quantizer : TensorQuantizer(8bit narrow fake per-tensor amax=30.1997 calibrator=MaxCalibrator scale=1.0 quant)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "W0831 15:32:57.006651 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n", + "W0831 15:32:57.009306 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([128, 1, 1, 1]).\n", + "W0831 15:32:57.011739 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n", + "W0831 15:32:57.014180 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([128, 1, 1, 1]).\n", + "W0831 15:32:57.016433 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n", + "W0831 15:32:57.018157 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([256, 1, 1, 1]).\n", + "W0831 15:32:57.019830 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n", + "W0831 15:32:57.021619 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([256, 1, 1, 1]).\n", + "W0831 15:32:57.023381 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n", + "W0831 15:32:57.024606 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([256, 1, 1, 1]).\n", + "W0831 15:32:57.026464 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n", + "W0831 15:32:57.027716 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).\n", + "W0831 15:32:57.029010 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n", + "W0831 15:32:57.030247 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).\n", + "W0831 15:32:57.031455 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n", + "W0831 15:32:57.032716 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).\n", + "W0831 15:32:57.034027 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n", + "W0831 15:32:57.035287 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).\n", + "W0831 15:32:57.036572 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n", + "W0831 15:32:57.037535 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).\n", + "W0831 15:32:57.038545 
140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n", + "W0831 15:32:57.039479 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).\n", + "W0831 15:32:57.040493 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n", + "W0831 15:32:57.041564 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([4096, 1]).\n", + "W0831 15:32:57.042597 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n", + "W0831 15:32:57.043280 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([4096, 1]).\n", + "W0831 15:32:57.044521 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n", + "W0831 15:32:57.045206 140586428176192 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([10, 1]).\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "features.3._weight_quantizer : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0137, 2.2506](64) calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.7._input_quantizer : TensorQuantizer(8bit narrow fake per-tensor amax=16.2026 calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.7._weight_quantizer : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0602, 1.3986](128) calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.10._input_quantizer : TensorQuantizer(8bit narrow fake per-tensor amax=9.1012 calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.10._weight_quantizer : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0841, 0.9074](128) calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.14._input_quantizer : TensorQuantizer(8bit narrow fake per-tensor amax=10.0201 calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.14._weight_quantizer : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0921, 0.7349](256) calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.17._input_quantizer : TensorQuantizer(8bit narrow fake per-tensor amax=7.0232 calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.17._weight_quantizer : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0406, 0.5232](256) calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.20._input_quantizer : TensorQuantizer(8bit narrow fake per-tensor amax=8.3654 calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.20._weight_quantizer : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0346, 0.4240](256) calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.24._input_quantizer : TensorQuantizer(8bit narrow fake per-tensor amax=7.5746 calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.24._weight_quantizer : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0218, 0.2763](512) calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.27._input_quantizer : TensorQuantizer(8bit narrow fake per-tensor amax=4.8754 calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.27._weight_quantizer : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0163, 0.1819](512) calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.30._input_quantizer : TensorQuantizer(8bit narrow fake per-tensor amax=3.7100 calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.30._weight_quantizer : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0113, 0.1578](512) calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.34._input_quantizer : TensorQuantizer(8bit narrow fake per-tensor amax=3.2465 calibrator=MaxCalibrator scale=1.0 quant)\n", + 
"features.34._weight_quantizer : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0137, 0.1480](512) calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.37._input_quantizer : TensorQuantizer(8bit narrow fake per-tensor amax=2.3264 calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.37._weight_quantizer : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0122, 0.2957](512) calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.40._input_quantizer : TensorQuantizer(8bit narrow fake per-tensor amax=3.4793 calibrator=MaxCalibrator scale=1.0 quant)\n", + "features.40._weight_quantizer : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0023, 0.6918](512) calibrator=MaxCalibrator scale=1.0 quant)\n", + "classifier.0._input_quantizer : TensorQuantizer(8bit narrow fake per-tensor amax=7.0113 calibrator=MaxCalibrator scale=1.0 quant)\n", + "classifier.0._weight_quantizer : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0027, 0.8358](4096) calibrator=MaxCalibrator scale=1.0 quant)\n", + "classifier.3._input_quantizer : TensorQuantizer(8bit narrow fake per-tensor amax=7.8033 calibrator=MaxCalibrator scale=1.0 quant)\n", + "classifier.3._weight_quantizer : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0024, 0.4038](4096) calibrator=MaxCalibrator scale=1.0 quant)\n", + "classifier.6._input_quantizer : TensorQuantizer(8bit narrow fake per-tensor amax=8.7469 calibrator=MaxCalibrator scale=1.0 quant)\n", + "classifier.6._weight_quantizer : TensorQuantizer(8bit narrow fake axis=0 amax=[0.3125, 0.4321](10) calibrator=MaxCalibrator scale=1.0 quant)\n" + ] + } + ], + "source": [ + "#Calibrate the model using max calibration technique.\n", + "with torch.no_grad():\n", + " calibrate_model(\n", + " model=qat_model,\n", + " model_name=\"vgg16\",\n", + " data_loader=training_dataloader,\n", + " num_calib_batch=32,\n", + " calibrator=\"max\",\n", + " hist_percentile=[99.9, 99.99, 99.999, 99.9999],\n", + " out_dir=\"./\")" + ] + }, + { + "cell_type": "markdown", + "id": "b00dcbf6", + "metadata": {}, + "source": [ + "\n", + "## 6. Quantization Aware Training" + ] + }, + { + "cell_type": "markdown", + "id": "f2e6fba8", + "metadata": {}, + "source": [ + "In this phase, we finetune the model weights and leave the quantizer node values frozen. The dynamic ranges for each layer obtained from the calibration are kept constant while the weights of the model are finetuned to be close to the accuracy of original FP32 model (model without quantizer nodes) is preserved. Usually the finetuning of QAT model should be quick compared to the full training of the original model. Use QAT to fine-tune for around 10% of the original training schedule with an annealing learning-rate. Please refer to Achieving FP32 Accuracy for INT8 Inference Using Quantization Aware Training with NVIDIA TensorRT for detailed recommendations. For this VGG model, it is enough to finetune for 1 epoch to get acceptable accuracy. \n", + "During finetuning with QAT, the quantization is applied as a composition of `max`, `clamp`, `round` and `mul` ops. \n", + "```\n", + "# amax is absolute maximum value for an input\n", + "# The upper bound for integer quantization (127 for int8)\n", + "max_bound = torch.tensor((2.0**(num_bits - 1 + int(unsigned))) - 1.0, device=amax.device)\n", + "scale = max_bound / amax\n", + "outputs = torch.clamp((inputs * scale).round_(), min_bound, max_bound)\n", + "```\n", + "tensor_quant function in `pytorch_quantization` toolkit is responsible for the above tensor quantization. 
+ { + "cell_type": "code", + "execution_count": 13, + "id": "87af7b92", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Updating learning rate: 0.1\n", + "Epoch: [ 1 / 1] LR: 0.100000\n", + "Batch: [ 500 | 1563] loss: 2.694\n", + "Batch: [ 1000 | 1563] loss: 2.682\n", + "Batch: [ 1500 | 1563] loss: 2.624\n", + "Test Loss: 0.03277 Test Acc: 83.58%\n", + "Checkpoint saved\n" + ] + } + ], + "source": [ + "# Finetune the QAT model for 1 epoch\n", + "num_epochs=1\n", + "for epoch in range(num_epochs):\n", + " adjust_lr(opt, epoch)\n", + " print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, num_epochs, state[\"lr\"]))\n", + "\n", + " train(qat_model, training_dataloader, crit, opt, epoch)\n", + " test_loss, test_acc = test(qat_model, testing_dataloader, crit, epoch)\n", + "\n", + " print(\"Test Loss: {:.5f} Test Acc: {:.2f}%\".format(test_loss, 100 * test_acc))\n", + " \n", + "save_checkpoint({'epoch': epoch + 1,\n", + " 'model_state_dict': qat_model.state_dict(),\n", + " 'acc': test_acc,\n", + " 'opt_state_dict': opt.state_dict(),\n", + " 'state': state},\n", + " ckpt_path=\"vgg16_qat_ckpt\")" + ] + },
+ { + "cell_type": "markdown", + "id": "f697ae8b", + "metadata": {}, + "source": [ + "\n", + "## 7. Export to Torchscript\n", + "Trace the model and convert it into Torchscript for deployment. To learn more about Torchscript, please refer to https://pytorch.org/docs/stable/jit.html. Setting `quant_nn.TensorQuantizer.use_fb_fake_quant = True` makes the QAT model use the `torch.fake_quantize_per_tensor_affine` and `torch.fake_quantize_per_channel_affine` operators, instead of the `tensor_quant` function, to export quantization operators. In Torchscript, they are represented as `aten::fake_quantize_per_tensor_affine` and `aten::fake_quantize_per_channel_affine`. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "b46906c8", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "E0831 15:41:34.662368 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.664751 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.671072 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.671867 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.683352 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.684193 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.687814 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.688531 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.698150 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.698921 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.702409 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.702994 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.711167 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.711931 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.714900 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.715603 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.725254 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.725864 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.728618 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.729140 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.736662 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.737521 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.739989 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.740708 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.748396 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.749184 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.751592 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.752305 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.764246 140586428176192 
tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.764994 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.767470 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.768118 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.775590 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.776468 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.778920 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.779547 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.787922 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.788623 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.791333 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.793220 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.802763 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.803504 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.805943 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.806617 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.814899 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.815649 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.818024 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.818692 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.826974 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.827722 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.830084 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.830769 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.844441 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.845136 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.847555 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.848293 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.856972 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.857702 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.860140 
140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.860877 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.868146 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.868999 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.872753 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.873387 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.931684 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.932640 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.935498 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.936259 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.944115 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.944886 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.947971 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.949408 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.958851 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.959626 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "E0831 15:41:34.962537 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.963227 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.970601 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.971469 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.974947 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.975533 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.985072 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.985844 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.988213 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.988955 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.997645 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:34.998368 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.001345 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.001920 140586428176192 
tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.009888 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.010627 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.013032 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.013727 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.022683 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.023485 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.025832 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.026518 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.033935 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.034775 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.039378 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.040091 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.047529 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.048348 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.051363 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.051893 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.060786 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.061613 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.065534 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.066100 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.073963 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.074629 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.077306 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.077896 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.085539 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.086258 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.089163 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.089860 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.103728 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.104618 
140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.107046 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.107893 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.116841 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.117565 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.120490 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.121185 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.128972 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.129700 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.132617 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n", + "E0831 15:41:35.133241 140586428176192 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n" + ] + } + ], + "source": [ + "quant_nn.TensorQuantizer.use_fb_fake_quant = True\n", + "with torch.no_grad():\n", + " data = iter(testing_dataloader)\n", + " images, _ = next(data)\n", + " jit_model = torch.jit.trace(qat_model, images.to(\"cuda\"))\n", + " torch.jit.save(jit_model, \"trained_vgg16_qat.jit.pt\")" + ] + },
+ { + "cell_type": "markdown", + "id": "f5986576", + "metadata": {}, + "source": [ + "\n", + "## 8. Inference using TRTorch\n", + "In this phase, we run the exported Torchscript graph of the VGG QAT model using TRTorch. TRTorch is a PyTorch-to-TensorRT compiler that converts Torchscript graphs into TensorRT engines. TensorRT 8.0 supports inference of quantization aware trained models and introduces new APIs: `QuantizeLayer` and `DequantizeLayer`. We can observe all the quantization nodes of the VGG QAT graph in the debug log of TRTorch. To enable debug logging, you can call `trtorch.logging.set_reportable_log_level(trtorch.logging.Level.Debug)`. For example, the `QuantConv2d` layer from the `pytorch_quantization` toolkit is represented as follows in Torchscript\n", + "```\n", + "%quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%x, %636, %637, %638, %639)\n", + "%quant_weight : Tensor = aten::fake_quantize_per_channel_affine(%394, %640, %641, %637, %638, %639)\n", + "%input.2 : Tensor = aten::_convolution(%quant_input, %quant_weight, %395, %687, %688, %689, %643, %690, %642, %643, %643, %644, %644)\n", + "```\n", + "`aten::fake_quantize_per_*_affine` is converted into `QuantizeLayer` + `DequantizeLayer` in TRTorch internally. Please refer to the quantization op converters in TRTorch.\n",
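+ "\n", + "To see these conversions while compiling, you can turn on debug logging right before calling `trtorch.compile` (a minimal sketch using the logging API mentioned above):\n", + "```python\n", + "trtorch.logging.set_reportable_log_level(trtorch.logging.Level.Debug)\n", + "```"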
+ ] + },
+ { + "cell_type": "code", + "execution_count": 17, + "id": "1629b222", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "VGG QAT accuracy using TensorRT: 83.59%\n" + ] + } + ], + "source": [ + "qat_model = torch.jit.load(\"trained_vgg16_qat.jit.pt\").eval()\n", + "\n", + "compile_spec = {\"inputs\": [trtorch.Input([16, 3, 32, 32])],\n", + " \"op_precision\": torch.int8,\n", + " }\n", + "trt_mod = trtorch.compile(qat_model, compile_spec)\n", + "\n", + "test_loss, test_acc = test(trt_mod, testing_dataloader, crit, 0)\n", + "print(\"VGG QAT accuracy using TensorRT: {:.2f}%\".format(100 * test_acc))" + ] + },
+ { + "cell_type": "markdown", + "id": "4e9ba8b5", + "metadata": {}, + "source": [ + "\n", + "## 9. References\n", + "* Very Deep Convolutional Networks for Large-Scale Image Recognition\n", + "* Achieving FP32 Accuracy for INT8 Inference Using Quantization Aware Training with NVIDIA TensorRT\n", + "* QAT workflow for VGG16\n", + "* Deploying VGG QAT model in C++ using TRTorch\n", + "* Pytorch-quantization toolkit from NVIDIA\n", + "* Pytorch-quantization toolkit user guide\n", + "* Quantization basics" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}