Skip to content

Commit

Permalink
Merge pull request #439 from NVIDIA/clean_up_ptq_recipe
Browse files Browse the repository at this point in the history
refactor: Fix python linting issues
  • Loading branch information
narendasan authored Apr 21, 2021
2 parents 2af2c11 + a39dea7 commit c3b8583
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 3 deletions.
1 change: 1 addition & 0 deletions .github/scripts/run_cpp_linter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
if output.returncode != 0:
comment = '''There are some changes that do not conform to C++ style guidelines:\n ```diff\n{}```'''.format(output.stdout.decode("utf-8"))
approval = 'REQUEST_CHANGES'
exit(1)

pr.create_review(commit, comment, approval)

Expand Down
1 change: 1 addition & 0 deletions .github/scripts/run_py_linter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@
if output.returncode != 0:
comment = '''There are some changes that do not conform to Python style guidelines:\n ```diff\n{}```'''.format(output.stdout.decode("utf-8"))
approval = 'REQUEST_CHANGES'
exit(1)

pr.create_review(commit, comment, approval)
2 changes: 1 addition & 1 deletion cpp/ptq/training/vgg16/export_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test(model, dataloader, crit):

with torch.no_grad():
for data, labels in dataloader:
data, labels = data.cuda(), labels.cuda(async=True)
data, labels = data.cuda(), labels.cuda(non_blocking=True)
out = model(data)
loss += crit(out, labels)
preds = torch.max(out, 1)[1]
Expand Down
2 changes: 1 addition & 1 deletion cpp/ptq/training/vgg16/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def train(model, dataloader, crit, opt, epoch):
model.train()
running_loss = 0.0
for batch, (data, labels) in enumerate(dataloader):
data, labels = data.cuda(), labels.cuda(async=True)
data, labels = data.cuda(), labels.cuda(non_blocking=True)
opt.zero_grad()
out = model(data)
loss = crit(out, labels)
Expand Down
2 changes: 1 addition & 1 deletion py/trtorch/_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
if "max_batch_size" in compile_spec:
assert type(compile_spec["max_batch_size"]) is int
info.max_batch_size = compile_spec["max_batch_size"]

if "truncate_long_and_double" in compile_spec:
assert type(compile_spec["truncate_long_and_double"]) is bool
info.truncate_long_and_double = compile_spec["truncate_long_and_double"]
Expand Down

0 comments on commit c3b8583

Please sign in to comment.