From 59868d21f15ac46576ec99cf54e9d6d169ff7b1c Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Wed, 21 Apr 2021 11:17:07 -0700 Subject: [PATCH 1/2] refactor: Fix python linting issues Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- cpp/ptq/training/vgg16/export_ckpt.py | 2 +- cpp/ptq/training/vgg16/main.py | 2 +- py/trtorch/_compile_spec.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/ptq/training/vgg16/export_ckpt.py b/cpp/ptq/training/vgg16/export_ckpt.py index 97860978c9..8c2e679b6b 100644 --- a/cpp/ptq/training/vgg16/export_ckpt.py +++ b/cpp/ptq/training/vgg16/export_ckpt.py @@ -22,7 +22,7 @@ def test(model, dataloader, crit): with torch.no_grad(): for data, labels in dataloader: - data, labels = data.cuda(), labels.cuda(async=True) + data, labels = data.cuda(), labels.cuda(non_blocking=True) out = model(data) loss += crit(out, labels) preds = torch.max(out, 1)[1] diff --git a/cpp/ptq/training/vgg16/main.py b/cpp/ptq/training/vgg16/main.py index bff714a47c..eb6ab96f0b 100644 --- a/cpp/ptq/training/vgg16/main.py +++ b/cpp/ptq/training/vgg16/main.py @@ -141,7 +141,7 @@ def train(model, dataloader, crit, opt, epoch): model.train() running_loss = 0.0 for batch, (data, labels) in enumerate(dataloader): - data, labels = data.cuda(), labels.cuda(async=True) + data, labels = data.cuda(), labels.cuda(non_blocking=True) opt.zero_grad() out = model(data) loss = crit(out, labels) diff --git a/py/trtorch/_compile_spec.py b/py/trtorch/_compile_spec.py index 3f1e135234..6e3aac072d 100644 --- a/py/trtorch/_compile_spec.py +++ b/py/trtorch/_compile_spec.py @@ -176,7 +176,7 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec: if "max_batch_size" in compile_spec: assert type(compile_spec["max_batch_size"]) is int info.max_batch_size = compile_spec["max_batch_size"] - + if "truncate_long_and_double" in compile_spec: assert type(compile_spec["truncate_long_and_double"]) is bool info.truncate_long_and_double = compile_spec["truncate_long_and_double"] From a39dea78756adb0aa2fc15707ce375e07f2b5316 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Wed, 21 Apr 2021 13:41:27 -0700 Subject: [PATCH 2/2] feat(//.github): Linter throws 1 when there needs to be style changes to show failing test to people whose PRs cant get linting inlined Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- .github/scripts/run_cpp_linter.py | 1 + .github/scripts/run_py_linter.py | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/scripts/run_cpp_linter.py b/.github/scripts/run_cpp_linter.py index b64306efe1..6f1e6b1259 100644 --- a/.github/scripts/run_cpp_linter.py +++ b/.github/scripts/run_cpp_linter.py @@ -23,6 +23,7 @@ if output.returncode != 0: comment = '''There are some changes that do not conform to C++ style guidelines:\n ```diff\n{}```'''.format(output.stdout.decode("utf-8")) approval = 'REQUEST_CHANGES' + exit(1) pr.create_review(commit, comment, approval) diff --git a/.github/scripts/run_py_linter.py b/.github/scripts/run_py_linter.py index 3b2ed70e38..e30a607a46 100644 --- a/.github/scripts/run_py_linter.py +++ b/.github/scripts/run_py_linter.py @@ -23,5 +23,6 @@ if output.returncode != 0: comment = '''There are some changes that do not conform to Python style guidelines:\n ```diff\n{}```'''.format(output.stdout.decode("utf-8")) approval = 'REQUEST_CHANGES' + exit(1) pr.create_review(commit, comment, approval)