Skip to content

Commit

Permalink
fix: Fixed failures for host deps sessions
Browse files Browse the repository at this point in the history
Signed-off-by: Anurag Dixit <[email protected]>
  • Loading branch information
andi4191 committed Mar 2, 2022
1 parent 8580423 commit ec2232f
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def train_model(session, use_host_env=False):

session.run_always('python',
'export_ckpt.py',
'vgg16_ckpts/ckpt_epoch25.pth')
'vgg16_ckpts/ckpt_epoch25.pth',
env={'PYTHONPATH': PYT_PATH})
else:
session.run_always('python',
'main.py',
Expand Down Expand Up @@ -146,13 +147,27 @@ def run_accuracy_tests(session, use_host_env=False):
else:
session.run_always("python", test)

def copy_model(session):
model_files = [ 'trained_vgg16.jit.pt',
'trained_vgg16_qat.jit.pt']

for file_name in model_files:
src_file = os.path.join(TOP_DIR, str('examples/int8/training/vgg16/') + file_name)
if os.path.exists(src_file):
session.run_always('cp',
'-rpf',
os.path.join(TOP_DIR, src_file),
os.path.join(TOP_DIR, str('tests/py/') + file_name),
external=True)

def run_int8_accuracy_tests(session, use_host_env=False):
print("Running accuracy tests")
copy_model(session)
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
tests = [
"test_ptq_dataloader.py",
"test_ptq_dataloader_calibrator.py",
"test_ptq_to_backend.py",
"test_qat_trt_accuracy",
"test_qat_trt_accuracy.py",
]
for test in tests:
if use_host_env:
Expand All @@ -162,9 +177,10 @@ def run_int8_accuracy_tests(session, use_host_env=False):

def run_trt_compatibility_tests(session, use_host_env=False):
print("Running TensorRT compatibility tests")
copy_model(session)
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
tests = [
"test_trt_intercompatibilty.py",
"test_trt_intercompatability.py",
"test_ptq_trt_calibrator.py",
]
for test in tests:
Expand Down Expand Up @@ -218,7 +234,7 @@ def run_l1_accuracy_tests(session, use_host_env=False):
install_deps(session)
install_torch_trt(session)
download_models(session, use_host_env)
download_datasets(session, use_host_env)
download_datasets(session)
train_model(session, use_host_env)
run_accuracy_tests(session, use_host_env)
cleanup(session)
Expand All @@ -228,7 +244,7 @@ def run_l1_int8_accuracy_tests(session, use_host_env=False):
install_deps(session)
install_torch_trt(session)
download_models(session, use_host_env)
download_datasets(session, use_host_env)
download_datasets(session)
train_model(session, use_host_env)
finetune_model(session, use_host_env)
run_int8_accuracy_tests(session, use_host_env)
Expand All @@ -239,6 +255,8 @@ def run_l2_trt_compatibility_tests(session, use_host_env=False):
install_deps(session)
install_torch_trt(session)
download_models(session, use_host_env)
download_datasets(session)
train_model(session, use_host_env)
run_trt_compatibility_tests(session, use_host_env)
cleanup(session)

Expand Down

0 comments on commit ec2232f

Please sign in to comment.