From b244423b75169ce7b8f33f7508aabfaad2269b22 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 10 Aug 2021 17:41:46 -0700 Subject: [PATCH] fix: Minor fixes to qat scripts Signed-off-by: Dheeraj Peri --- examples/int8/training/vgg16/main.py | 2 +- examples/int8/training/vgg16/train_qat.py | 2 +- py/trtorch/csrc/tensorrt_backend.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/int8/training/vgg16/main.py b/examples/int8/training/vgg16/main.py index 627688cf9b..6185cc210e 100644 --- a/examples/int8/training/vgg16/main.py +++ b/examples/int8/training/vgg16/main.py @@ -124,7 +124,7 @@ def main(): print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc)) - if epoch % 10 == 9: + if epoch % 10 == 9 or epoch==args.epochs-1: save_checkpoint( { 'epoch': epoch + 1, diff --git a/examples/int8/training/vgg16/train_qat.py b/examples/int8/training/vgg16/train_qat.py index 8a770ef118..03a8b908f6 100644 --- a/examples/int8/training/vgg16/train_qat.py +++ b/examples/int8/training/vgg16/train_qat.py @@ -183,7 +183,7 @@ def main(): crit = nn.CrossEntropyLoss() opt = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) - import pdb; pdb.set_trace() + if args.start_from != 0: ckpt_file = args.ckpt_dir + '/ckpt_epoch' + str(args.start_from) + '.pth' print('Loading from checkpoint {}'.format(ckpt_file)) diff --git a/py/trtorch/csrc/tensorrt_backend.cpp b/py/trtorch/csrc/tensorrt_backend.cpp index 75dad26644..503a495956 100644 --- a/py/trtorch/csrc/tensorrt_backend.cpp +++ b/py/trtorch/csrc/tensorrt_backend.cpp @@ -27,7 +27,7 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl:: mod = core::lowering::LowerModule(mod); auto spec = c10::impl::toTypedDict(method_compile_spec); - lowering::LowerInfo lower_info; + core::lowering::LowerInfo lower_info; for (auto it = spec.begin(), end = spec.end(); it != end; ++it) { const auto& method_name = it->key(); auto method = mod.get_method(method_name);