diff --git a/py/trtorch/ptq.py b/py/trtorch/ptq.py index 0a8e3a722d..cfdd8a67b8 100644 --- a/py/trtorch/ptq.py +++ b/py/trtorch/ptq.py @@ -26,7 +26,8 @@ def get_batch_size(self): def get_batch(self, names): - if self.current_batch_idx + self.batch_size > self.data_loader.dataset.data.shape[0]: + print("Current batch idx: ", self.current_batch_idx, " Dataset size: ", len(self.data_loader.dataset)) + if self.current_batch_idx + self.batch_size > len(self.data_loader.dataset): return None batch = self.dataset_iterator.next()