Skip to content

Commit

Permalink
fix: Minor fixes to qat scripts
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <[email protected]>
  • Loading branch information
peri044 committed Aug 11, 2021
1 parent b7f6d8a commit b244423
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/int8/training/vgg16/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def main():

print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))

if epoch % 10 == 9:
if epoch % 10 == 9 or epoch==args.epochs-1:
save_checkpoint(
{
'epoch': epoch + 1,
Expand Down
2 changes: 1 addition & 1 deletion examples/int8/training/vgg16/train_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def main():

crit = nn.CrossEntropyLoss()
opt = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
import pdb; pdb.set_trace()

if args.start_from != 0:
ckpt_file = args.ckpt_dir + '/ckpt_epoch' + str(args.start_from) + '.pth'
print('Loading from checkpoint {}'.format(ckpt_file))
Expand Down
2 changes: 1 addition & 1 deletion py/trtorch/csrc/tensorrt_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl::
mod = core::lowering::LowerModule(mod);

auto spec = c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
lowering::LowerInfo lower_info;
core::lowering::LowerInfo lower_info;
for (auto it = spec.begin(), end = spec.end(); it != end; ++it) {
const auto& method_name = it->key();
auto method = mod.get_method(method_name);
Expand Down

0 comments on commit b244423

Please sign in to comment.