From 9dc60618a6dfc845b454389b676de312dbe9fe82 Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Thu, 12 Aug 2021 11:30:52 -0700
Subject: [PATCH] fix(qat): Rescale input data for C++ application

Signed-off-by: Naren Dasan
Signed-off-by: Naren Dasan
---
 examples/int8/datasets/cifar10.cpp | 2 +-
 examples/int8/ptq/main.cpp | 1 +
 examples/int8/qat/main.cpp | 1 +
 3 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/examples/int8/datasets/cifar10.cpp b/examples/int8/datasets/cifar10.cpp
index 10c52973f8..dff4716dac 100644
--- a/examples/int8/datasets/cifar10.cpp
+++ b/examples/int8/datasets/cifar10.cpp
@@ -50,7 +50,7 @@ std::pair read_batch(const std::string& path) {
     labels.push_back(label);
     auto image_tensor = torch::from_blob(image.data(), {kImageChannels, kImageDim, kImageDim}, torch::TensorOptions().dtype(torch::kU8))
-                            .to(torch::kF32);
+                            .to(torch::kF32).div(255);
     images.push_back(image_tensor);
   }
diff --git a/examples/int8/ptq/main.cpp b/examples/int8/ptq/main.cpp
index eb96adb98f..752d3a84fe 100644
--- a/examples/int8/ptq/main.cpp
+++ b/examples/int8/ptq/main.cpp
@@ -140,4 +140,5 @@ int main(int argc, const char* argv[]) {
   auto trt_runtimes = benchmark_module(trt_mod, dims[0]);
   print_avg_std_dev("TRT quantized model", trt_runtimes, dims[0][0]);
+  trt_mod.save("/tmp/ptq_vgg16.trt.ts");
 }
diff --git a/examples/int8/qat/main.cpp b/examples/int8/qat/main.cpp
index 33e5e295bb..50db43ec1e 100644
--- a/examples/int8/qat/main.cpp
+++ b/examples/int8/qat/main.cpp
@@ -124,5 +124,6 @@ int main(int argc, const char* argv[]) {
   auto trt_runtimes = benchmark_module(trt_mod, dims[0]);
   print_avg_std_dev("TRT quantized model", trt_runtimes, dims[0][0]);
+  trt_mod.save("/tmp/qat_vgg16.trt.ts");
 }