From bfeed1488d50d72f433194aa9859586f2e0f51cc Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 10 Nov 2021 15:47:38 +0100 Subject: [PATCH] Update train, val `tqdm` to fixed width (#5367) * Update tqdm for fixed width * Update val.py * Update val.py * Try ncols= in train.py * NCOLS * NCOLS * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * bar_format * position 0 leave true * exp0 * auto * auto * Cleanup * Cleanup * Cleanup Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- train.py | 11 +++++------ utils/general.py | 5 +++++ val.py | 5 +++-- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/train.py b/train.py index fedc55d8be5c..4193365d5a09 100644 --- a/train.py +++ b/train.py @@ -5,7 +5,6 @@ Usage: $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640 """ - import argparse import math import os @@ -40,10 +39,10 @@ from utils.callbacks import Callbacks from utils.datasets import create_dataloader from utils.downloads import attempt_download -from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, - check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds, - intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle, - print_args, print_mutation, strip_optimizer) +from utils.general import (LOGGER, NCOLS, check_dataset, check_file, check_git_status, check_img_size, + check_requirements, check_suffix, check_yaml, colorstr, get_latest_run, increment_path, + init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, + one_cycle, print_args, print_mutation, strip_optimizer) from utils.loggers import Loggers from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.loss import ComputeLoss @@ -289,7 +288,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary pbar = enumerate(train_loader) LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size')) if RANK in [-1, 0]: - pbar = tqdm(pbar, total=nb) # progress bar + pbar = tqdm(pbar, total=nb, ncols=NCOLS, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar optimizer.zero_grad() for i, (imgs, targets, paths, _) in pbar: # batch ------------------------------------------------------------- ni = i + nb * epoch # number integrated batches (since train start) diff --git a/utils/general.py b/utils/general.py index 8f59d487edfb..fa56ed49aba8 100755 --- a/utils/general.py +++ b/utils/general.py @@ -11,6 +11,7 @@ import platform import random import re +import shutil import signal import time import urllib @@ -834,3 +835,7 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False): if mkdir: path.mkdir(parents=True, exist_ok=True) # make directory return path + + +# Variables +NCOLS = 0 if is_docker() else shutil.get_terminal_size().columns # terminal window size diff --git a/val.py b/val.py index 2bcbc582a500..62a30ac09d39 100644 --- a/val.py +++ b/val.py @@ -26,7 +26,7 @@ from models.common import DetectMultiBackend from utils.callbacks import Callbacks from utils.datasets import create_dataloader -from utils.general import (LOGGER, box_iou, check_dataset, check_img_size, check_requirements, check_yaml, +from utils.general import (LOGGER, NCOLS, box_iou, check_dataset, check_img_size, check_requirements, check_yaml, coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args, scale_coords, xywh2xyxy, xyxy2xywh) from utils.metrics import ConfusionMatrix, ap_per_class @@ -162,7 +162,8 @@ def run(data, dt, p, r, f1, mp, mr, map50, map = [0.0, 0.0, 0.0], 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 loss = torch.zeros(3, device=device) jdict, stats, ap, ap_class = [], [], [], [] - for batch_i, (im, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)): + pbar = tqdm(dataloader, desc=s, ncols=NCOLS, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar + for batch_i, (im, targets, paths, shapes) in enumerate(pbar): t1 = time_sync() if pt: im = im.to(device, non_blocking=True)