diff --git a/bootstrap/engines/engine.py b/bootstrap/engines/engine.py index de37236..86a74a4 100755 --- a/bootstrap/engines/engine.py +++ b/bootstrap/engines/engine.py @@ -200,8 +200,8 @@ def train_epoch(self, model, dataset, optimizer, epoch, mode='train'): for key, value in out.items(): if torch.is_tensor(value): - if value.dim() <= 1: - value = value.item() # get number from a torch scalar + if value.numel() <= 1: + value = value.item() # get number from a torch scalar else: continue if isinstance(value, (list, dict, tuple)):