diff --git a/bootstrap/engines/engine.py b/bootstrap/engines/engine.py index 86a74a4..88c63a8 100755 --- a/bootstrap/engines/engine.py +++ b/bootstrap/engines/engine.py @@ -201,7 +201,7 @@ def train_epoch(self, model, dataset, optimizer, epoch, mode='train'): for key, value in out.items(): if torch.is_tensor(value): if value.numel() <= 1: - value = value.item() # get number from a torch scalar + value = value.item() # get number from a torch scalar else: continue if isinstance(value, (list, dict, tuple)):