Skip to content

Commit

Permalink
docs(ptq): Adding a tutorial on how to use PTQ
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Mar 17, 2021
1 parent 06e796e commit dc4d966
Showing 1 changed file with 70 additions and 5 deletions.
75 changes: 70 additions & 5 deletions docsrc/tutorials/ptq.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@ the TensorRT calibrator. With TRTorch we look to leverage existing infrastructur
calibrators easier.

LibTorch provides a ``DataLoader`` and ``Dataset`` API which steamlines preprocessing and batching input data.
This section of the PyTorch documentation has more information https://pytorch.org/tutorials/advanced/cpp_frontend.html#loading-data.
These APIs are exposed via both C++ and Python interface which makes it easier for the end user.
For C++ interface, we use ``torch::Dataset`` and ``torch::data::make_data_loader`` objects to construct and perform pre-processing on datasets.
The equivalent functionality in python interface uses ``torch.utils.data.Dataset`` and ``torch.utils.data.DataLoader``.
This section of the PyTorch documentation has more information https://pytorch.org/tutorials/advanced/cpp_frontend.html#loading-data and https://pytorch.org/tutorials/recipes/recipes/loading_data_recipe.html.
TRTorch uses Dataloaders as the base of a generic calibrator implementation. So you will be able to reuse or quickly
implement a ``torch::Dataset`` for your target domain, place it in a DataLoader and create a INT8 Calibrator
which you can provide to TRTorch to run INT8 Calibration during compliation of your module.

.. _writing_ptq:
.. _writing_ptq_cpp:

How to create your own PTQ application
How to create your own PTQ application in C++
----------------------------------------

Here is an example interface of a ``torch::Dataset`` class for CIFAR10:
Expand Down Expand Up @@ -132,14 +135,76 @@ Then all thats required to setup the module for INT8 calibration is to set the f
auto trt_mod = trtorch::CompileGraph(mod, compile_spec);

If you have an existing Calibrator implementation for TensorRT you may directly set the ``ptq_calibrator`` field with a pointer to your calibrator and it will work as well.

From here not much changes in terms of how to execution works. You are still able to fully use LibTorch as the sole interface for inference. Data should remain
in FP32 precision when it's passed into `trt_mod.forward`. There exists an example application in the TRTorch demo that takes you from training a VGG16 network on
CIFAR10 to deploying in INT8 with TRTorch here: https://github.com/NVIDIA/TRTorch/tree/master/cpp/ptq

.. _writing_ptq_python:

How to create your own PTQ application in Python
----------------------------------------

TRTorch Python API provides an easy and convenient way to use pytorch dataloaders with TensorRT calibrators. ``DataLoaderCalibrator`` class can be used to create
a TensorRT calibrator by providing desired configuration. The following code demonstrates an example on how to use it

.. code-block:: python
testing_dataset = torchvision.datasets.CIFAR10(root='./data',
train=False,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010))
]))
testing_dataloader = torch.utils.data.DataLoader(testing_dataset,
batch_size=1,
shuffle=False,
num_workers=1)
calibrator = trtorch.ptq.DataLoaderCalibrator(testing_dataloader,
cache_file='./calibration.cache',
use_cache=False,
algo_type=trtorch.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
device=torch.device('cuda:0'))
compile_spec = {
"input_shapes": [[1, 3, 32, 32]],
"op_precision": torch.int8,
"calibrator": calibrator,
"device": {
"device_type": trtorch.DeviceType.GPU,
"gpu_id": 0,
"dla_core": 0,
"allow_gpu_fallback": False,
"disable_tf32": False
}
}
trt_mod = trtorch.compile(model, compile_spec)
In the cases where there is a pre-existing calibration cache file that users want to use, ``CacheCalibrator`` can be used without any dataloaders. The following example demonstrates how
to use ``CacheCalibrator`` to use in INT8 mode.

.. code-block:: python
calibrator = trtorch.ptq.CacheCalibrator("./calibration.cache")
compile_settings = {
"input_shapes": [[1, 3, 32, 32]],
"op_precision": torch.int8,
"calibrator": calibrator,
"max_batch_size": 32,
}
trt_mod = trtorch.compile(model, compile_settings)
If you already have an existing calibrator class (implemented directly using TensorRT API), you can directly set the calibrator field to your class which can be very convenient.
For a demo on how PTQ can be performed on a VGG network using TRTorch API, you can refer to https://github.com/NVIDIA/TRTorch/blob/master/tests/py/test_ptq_dataloader_calibrator.py
and https://github.com/NVIDIA/TRTorch/blob/master/tests/py/test_ptq_trt_calibrator.py

Citations
^^^^^^^^^^^

Krizhevsky, A., & Hinton, G. (2009). Learning multiple layers of features from tiny images.

Simonyan, K., & Zisserman, A. (2014). Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556.
Simonyan, K., & Zisserman, A. (2014). Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556.

0 comments on commit dc4d966

Please sign in to comment.