W&B: Restructure code to support the new dataset_check() feature (#4197)
* Improve docstrings and run names

* default wandb login prompt with timeout

* return key

* Update api_key check logic

* Properly support zipped dataset feature

* update docstring

* Revert tutorial change

* extend changes to log_dataset

* add run name

* bug fix

* bug fix

* Update comment

* fix import check

* remove unused import

* Hardcode .yaml file extension

* reduce code

* Reformat using pycharm

Co-authored-by: Glenn Jocher <[email protected]>
AyushExel and glenn-jocher authored Jul 28, 2021
1 parent 2683b18 commit e88e8f7
Showing 6 changed files with 52 additions and 40 deletions.
Empty file modified README.md (100755 → 100644)
17 changes: 11 additions & 6 deletions train.py
@@ -73,24 +73,29 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
yaml.safe_dump(hyp, f, sort_keys=False)
with open(save_dir / 'opt.yaml', 'w') as f:
yaml.safe_dump(vars(opt), f, sort_keys=False)
data_dict = None

# Loggers
if RANK in [-1, 0]:
loggers = Loggers(save_dir, weights, opt, hyp, LOGGER).start() # loggers dict
if loggers.wandb:
data_dict = loggers.wandb.data_dict
if resume:
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp


# Config
plots = not evolve # create plots
cuda = device.type != 'cpu'
init_seeds(1 + RANK)
with torch_distributed_zero_first(RANK):
data_dict = check_dataset(data) # check
data_dict = data_dict or check_dataset(data) # check if None
train_path, val_path = data_dict['train'], data_dict['val']
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset

# Loggers
if RANK in [-1, 0]:
loggers = Loggers(save_dir, weights, opt, hyp, data_dict, LOGGER).start() # loggers dict
if loggers.wandb and resume:
weights, epochs, hyp, data_dict = opt.weights, opt.epochs, opt.hyp, loggers.wandb.data_dict

# Model
pretrained = weights.endswith('.pt')
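
Because the added/removed markers are lost in this view, the net effect of the train.py hunk is easier to see reassembled: the Loggers block now runs before the dataset check, so a W&B-resolved dataset dictionary can take precedence. A minimal sketch of that fallback, with parameters standing in for train()'s locals:

def resolve_dataset(loggers, data, check_dataset):
    # Condensed reading of the reordered logic above: prefer the dataset dict the W&B
    # logger already resolved (e.g. when resuming a run whose dataset is a W&B artifact),
    # and only fall back to parsing the local data.yaml / .zip with check_dataset().
    data_dict = loggers.wandb.data_dict if getattr(loggers, 'wandb', None) else None
    return data_dict or check_dataset(data)  # mirrors: data_dict = data_dict or check_dataset(data)
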
13 changes: 4 additions & 9 deletions utils/loggers/__init__.py
@@ -1,9 +1,7 @@
# YOLOv5 experiment logging utils

import torch
import warnings
from threading import Thread

import torch
from torch.utils.tensorboard import SummaryWriter

from utils.general import colorstr, emojis
@@ -23,12 +21,11 @@

class Loggers():
# YOLOv5 Loggers class
def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, data_dict=None, logger=None, include=LOGGERS):
def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None, include=LOGGERS):
self.save_dir = save_dir
self.weights = weights
self.opt = opt
self.hyp = hyp
self.data_dict = data_dict
self.logger = logger # for printing results to console
self.include = include
for k in LOGGERS:
@@ -38,9 +35,7 @@ def start(self):
self.csv = True # always log to csv

# Message
try:
import wandb
except ImportError:
if not wandb:
prefix = colorstr('Weights & Biases: ')
s = f"{prefix}run 'pip install wandb' to automatically track and visualize YOLOv5 🚀 runs (RECOMMENDED)"
print(emojis(s))
@@ -57,7 +52,7 @@ def start(self):
assert 'wandb' in self.include and wandb
run_id = torch.load(self.weights).get('wandb_id') if self.opt.resume else None
self.opt.hyp = self.hyp # add hyperparameters
self.wandb = WandbLogger(self.opt, run_id, self.data_dict)
self.wandb = WandbLogger(self.opt, run_id)
except:
self.wandb = None
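
The old try/except ImportError path becomes a plain `if not wandb:` test. That only works because the module already binds `wandb` to None when the import fails; the guard it relies on looks roughly like this (the exact wording in the repo may differ slightly):

# Module-level import guard assumed by the new check: if the wandb package is missing
# (or shadowed by a local directory), the name is bound to None so later code can test
# truthiness instead of re-importing inside try/except.
try:
    import wandb
    assert hasattr(wandb, '__version__')  # verify it is the installed package, not a local dir
except (ImportError, AssertionError):
    wandb = None
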

6 changes: 2 additions & 4 deletions utils/loggers/wandb/log_dataset.py
@@ -1,5 +1,4 @@
import argparse

import yaml

from wandb_utils import WandbLogger
@@ -8,9 +7,7 @@


def create_dataset_artifact(opt):
with open(opt.data, encoding='ascii', errors='ignore') as f:
data = yaml.safe_load(f) # data dict
logger = WandbLogger(opt, '', None, data, job_type='Dataset Creation') # TODO: return value unused
logger = WandbLogger(opt, None, job_type='Dataset Creation') # TODO: return value unused


if __name__ == '__main__':
@@ -19,6 +16,7 @@ def create_dataset_artifact(opt):
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
parser.add_argument('--project', type=str, default='YOLOv5', help='name of W&B Project')
parser.add_argument('--entity', default=None, help='W&B entity')
parser.add_argument('--name', type=str, default='log dataset', help='name of W&B run')

opt = parser.parse_args()
opt.resume = False # Explicitly disallow resume check for dataset upload job
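
log_dataset.py no longer parses the dataset .yaml itself; it passes the raw path through and lets WandbLogger run the check. A hypothetical invocation after this change (values illustrative, not from the repository):

# CLI form (illustrative):
#   python utils/loggers/wandb/log_dataset.py --data data/coco128.yaml --name "log dataset"
# Programmatic equivalent:
from argparse import Namespace

opt = Namespace(data='data/coco128.yaml',  # plain .yaml (or .zip) path, checked inside WandbLogger
                single_cls=False,
                project='YOLOv5',
                entity=None,
                name='log dataset',        # new flag added in this commit: names the W&B run
                resume=False)              # dataset-upload jobs never resume
# create_dataset_artifact(opt)  # would build WandbLogger(opt, None, job_type='Dataset Creation')
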
3 changes: 1 addition & 2 deletions utils/loggers/wandb/sweep.py
@@ -1,7 +1,6 @@
import sys
from pathlib import Path

import wandb
from pathlib import Path

FILE = Path(__file__).absolute()
sys.path.append(FILE.parents[2].as_posix()) # add utils/ to path
53 changes: 34 additions & 19 deletions utils/loggers/wandb/wandb_utils.py
@@ -3,10 +3,9 @@
import logging
import os
import sys
import yaml
from contextlib import contextmanager
from pathlib import Path

import yaml
from tqdm import tqdm

FILE = Path(__file__).absolute()
@@ -99,7 +98,7 @@ class WandbLogger():
https://docs.wandb.com/guides/integrations/yolov5
"""

def __init__(self, opt, run_id, data_dict, job_type='Training'):
def __init__(self, opt, run_id, job_type='Training'):
"""
- Initialize WandbLogger instance
- Upload dataset if opt.upload_dataset is True
@@ -108,7 +107,6 @@ def __init__(self, opt, run_id, data_dict, job_type='Training'):
arguments:
opt (namespace) -- Commandline arguments for this run
run_id (str) -- Run ID of W&B run to be resumed
data_dict (Dict) -- Dictionary containing info about the dataset to be used
job_type (str) -- To set the job_type for this run
"""
@@ -119,10 +117,11 @@ def __init__(self, opt, run_id, data_dict, job_type='Training'):
self.train_artifact_path, self.val_artifact_path = None, None
self.result_artifact = None
self.val_table, self.result_table = None, None
self.data_dict = data_dict
self.bbox_media_panel_images = []
self.val_table_path_map = None
self.max_imgs_to_log = 16
self.wandb_artifact_data_dict = None
self.data_dict = None
# It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
if isinstance(opt.resume, str): # checks resume from artifact
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
@@ -148,11 +147,23 @@ def __init__(self, opt, run_id, data_dict, job_type='Training'):
if self.wandb_run:
if self.job_type == 'Training':
if not opt.resume:
wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
# Info useful for resuming from artifacts
self.wandb_run.config.update({'opt': vars(opt), 'data_dict': wandb_data_dict},
allow_val_change=True)
self.data_dict = self.setup_training(opt, data_dict)
if opt.upload_dataset:
self.wandb_artifact_data_dict = self.check_and_upload_dataset(opt)

elif opt.data.endswith('_wandb.yaml'): # When dataset is W&B artifact
with open(opt.data, encoding='ascii', errors='ignore') as f:
data_dict = yaml.safe_load(f)
self.data_dict = data_dict
else: # Local .yaml dataset file or .zip file
self.data_dict = check_dataset(opt.data)

self.setup_training(opt)
# write data_dict to config. useful for resuming from artifacts
if not self.wandb_artifact_data_dict:
self.wandb_artifact_data_dict = self.data_dict
self.wandb_run.config.update({'data_dict': self.wandb_artifact_data_dict},
allow_val_change=True)

if self.job_type == 'Dataset Creation':
self.data_dict = self.check_and_upload_dataset(opt)
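
The core of the restructure is that __init__ now decides where the dataset dictionary comes from instead of receiving it as an argument. A self-contained sketch of that decision order, with parameters standing in for self/opt state and the imported helpers:

import yaml


def resolve_logger_data_dict(opt, check_dataset, check_and_upload_dataset):
    # Condensed decision order from the new __init__ above. The logger tracks the
    # artifact-backed dict separately from the plain one; whichever is produced is later
    # mirrored into wandb_run.config['data_dict'] so resumed runs can recover it.
    data_dict, wandb_artifact_data_dict = None, None
    if opt.upload_dataset:                           # push the dataset to W&B as an artifact
        wandb_artifact_data_dict = check_and_upload_dataset(opt)
    elif opt.data.endswith('_wandb.yaml'):           # dataset was already logged as an artifact
        with open(opt.data, encoding='ascii', errors='ignore') as f:
            data_dict = yaml.safe_load(f)
    else:                                            # local .yaml or .zip dataset file
        data_dict = check_dataset(opt.data)          # the new check/unzip happens here
    return data_dict, wandb_artifact_data_dict
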

@@ -167,15 +178,15 @@ def check_and_upload_dataset(self, opt):
Updated dataset info dictionary where local dataset paths are replaced by WANDB_ARTIFACT_PREFIX links.
"""
assert wandb, 'Install wandb to upload dataset'
config_path = self.log_dataset_artifact(check_file(opt.data),
config_path = self.log_dataset_artifact(opt.data,
opt.single_cls,
'YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem)
print("Created dataset config file ", config_path)
with open(config_path, encoding='ascii', errors='ignore') as f:
wandb_data_dict = yaml.safe_load(f)
return wandb_data_dict

def setup_training(self, opt, data_dict):
def setup_training(self, opt):
"""
Setup the necessary processes for training YOLO models:
- Attempt to download model checkpoint and dataset artifacts if opt.resume starts with WANDB_ARTIFACT_PREFIX
@@ -184,10 +195,7 @@ def setup_training(self, opt, data_dict):
arguments:
opt (namespace) -- commandline arguments for this run
data_dict (Dict) -- Dataset dictionary for this run
returns:
data_dict (Dict) -- contains the updated info about the dataset to be used for training
"""
self.log_dict, self.current_epoch = {}, 0
self.bbox_interval = opt.bbox_interval
@@ -198,8 +206,10 @@ def setup_training(self, opt):
config = self.wandb_run.config
opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str(
self.weights), config.save_period, config.batch_size, config.bbox_interval, config.epochs, \
config.opt['hyp']
config.hyp
data_dict = dict(self.wandb_run.config.data_dict) # eliminates the need for config file to resume
else:
data_dict = self.data_dict
if self.val_artifact is None: # If --upload_dataset is set, use the existing artifact, don't download
self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
opt.artifact_alias)
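
On resume, setup_training now reads the dataset dictionary back out of the run's W&B config (stored at training start) instead of requiring the original data.yaml. A minimal standalone sketch of that round trip, assuming a configured wandb environment (or WANDB_MODE=offline); project name and values are placeholders:

import wandb

run = wandb.init(project='YOLOv5')  # placeholder run
# At training start the logger stores the resolved dataset dict on the run config:
run.config.update({'data_dict': {'train': 'images/train', 'val': 'images/val', 'nc': 80}},
                  allow_val_change=True)
# On resume the same dict is read back, so data.yaml is not needed on the resuming machine:
data_dict = dict(run.config.data_dict)
run.finish()
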
@@ -221,7 +231,10 @@ def setup_training(self, opt):
self.map_val_table_path()
if opt.bbox_interval == -1:
self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
return data_dict
train_from_artifact = self.train_artifact_path is not None and self.val_artifact_path is not None
# Update the data_dict to point to local artifacts dir
if train_from_artifact:
self.data_dict = data_dict

def download_dataset_artifact(self, path, alias):
"""
@@ -299,7 +312,8 @@ def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=
returns:
the new .yaml file with artifact links. it can be used to start training directly from artifacts
"""
data = check_dataset(data_file) # parse and check
self.data_dict = check_dataset(data_file) # parse and check
data = dict(self.data_dict)
nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names'])
names = {k: v for k, v in enumerate(names)} # to index dictionary
self.train_artifact = self.create_dataset_table(LoadImagesAndLabels(
@@ -310,7 +324,8 @@ def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=
data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train')
if data.get('val'):
data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val')
path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1)) # updated data.yaml path
path = Path(data_file).stem
path = (path if overwrite_config else path + '_wandb') + '.yaml' # updated data.yaml path
data.pop('download', None)
data.pop('path', None)
with open(path, 'w') as f:
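
The new config-path logic hardcodes the .yaml extension (so .zip datasets still get a .yaml artifact config) and builds the name from the file stem, which means the file now lands in the working directory rather than next to data_file. A self-contained worked example:

from pathlib import Path

for data_file in ('data/coco128.yaml', 'coco128.zip'):
    overwrite_config = False
    path = Path(data_file).stem                                      # 'coco128'
    path = (path if overwrite_config else path + '_wandb') + '.yaml'
    print(data_file, '->', path)
# data/coco128.yaml -> coco128_wandb.yaml   (old code produced data/coco128_wandb.yaml)
# coco128.zip       -> coco128_wandb.yaml   (old code produced coco128_wandb.zip)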