Skip to content

Commit

Permalink
Merge pull request #1281 from pytorch/data_loader
Browse files Browse the repository at this point in the history
chore: Fix data loader issues and nox file paths
  • Loading branch information
peri044 authored Sep 8, 2022
2 parents f16ac7b + 3b45c80 commit 7142c82
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion examples/int8/training/vgg16/export_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test(model, dataloader, crit):
quant_nn.TensorQuantizer.use_fb_fake_quant = True
with torch.no_grad():
data = iter(testing_dataloader)
images, _ = data.next()
images, _ = next(data)
jit_model = torch.jit.trace(model, images.to("cuda"))
torch.jit.save(jit_model, "trained_vgg16_qat.jit.pt")

Expand Down
6 changes: 3 additions & 3 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,8 @@ def run_trt_compatibility_tests(session):
copy_model(session)
session.chdir(os.path.join(TOP_DIR, "tests/py"))
tests = [
"test_trt_intercompatibility.py",
"test_ptq_trt_calibrator.py",
"integrations/test_trt_intercompatibility.py",
# "ptq/test_ptq_trt_calibrator.py",
]
for test in tests:
if USE_HOST_DEPS:
Expand All @@ -295,7 +295,7 @@ def run_multi_gpu_tests(session):
print("Running multi GPU tests")
session.chdir(os.path.join(TOP_DIR, "tests/py"))
tests = [
"test_multi_gpu.py",
"hw/test_multi_gpu.py",
]
for test in tests:
if USE_HOST_DEPS:
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_batch(self, names):
if self.current_batch_idx + self.batch_size > len(self.data_loader.dataset):
return None

batch = self.dataset_iterator.next()
batch = next(self.dataset_iterator)
self.current_batch_idx += self.batch_size
inputs_gpu = []
if isinstance(batch, list):
Expand Down
2 changes: 1 addition & 1 deletion tests/py/ptq/test_ptq_trt_calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def get_batch(self, names):
):
return None

batch = self.dataset_iterator.next()
batch = next(self.dataset_iterator)
self.current_batch_idx += self.batch_size
# Treat the first element as input and others as targets.
if isinstance(batch, list):
Expand Down

0 comments on commit 7142c82

Please sign in to comment.