-
-
Notifications
You must be signed in to change notification settings - Fork 16.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* New CSV Logger * cleanup * move batch plots into Logger * rename comment * Remove total loss from progress bar * mloss :-1 bug fix * Update plot_results() * Update plot_results() * plot_results bug fix
- Loading branch information
1 parent
3764277
commit 96e36a7
Showing
6 changed files
with
68 additions
and
109 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,6 @@ | |
import time | ||
from copy import deepcopy | ||
from pathlib import Path | ||
from threading import Thread | ||
|
||
import math | ||
import numpy as np | ||
|
@@ -38,7 +37,7 @@ | |
check_requirements, print_mutation, set_logging, one_cycle, colorstr | ||
from utils.google_utils import attempt_download | ||
from utils.loss import ComputeLoss | ||
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution | ||
from utils.plots import plot_labels, plot_evolution | ||
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel | ||
from utils.loggers.wandb.wandb_utils import check_wandb_resume | ||
from utils.metrics import fitness | ||
|
@@ -61,7 +60,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |
# Directories | ||
w = save_dir / 'weights' # weights dir | ||
w.mkdir(parents=True, exist_ok=True) # make dir | ||
last, best, results_file = w / 'last.pt', w / 'best.pt', save_dir / 'results.txt' | ||
last, best = w / 'last.pt', w / 'best.pt' | ||
|
||
# Hyperparameters | ||
if isinstance(hyp, str): | ||
|
@@ -88,7 +87,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |
|
||
# Loggers | ||
if RANK in [-1, 0]: | ||
loggers = Loggers(save_dir, results_file, weights, opt, hyp, data_dict, LOGGER).start() # loggers dict | ||
loggers = Loggers(save_dir, weights, opt, hyp, data_dict, LOGGER).start() # loggers dict | ||
if loggers.wandb and resume: | ||
weights, epochs, hyp, data_dict = opt.weights, opt.epochs, opt.hyp, loggers.wandb.data_dict | ||
|
||
|
@@ -167,10 +166,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |
ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) | ||
ema.updates = ckpt['updates'] | ||
|
||
# Results | ||
if ckpt.get('training_results') is not None: | ||
results_file.write_text(ckpt['training_results']) # write results.txt | ||
|
||
# Epochs | ||
start_epoch = ckpt['epoch'] + 1 | ||
if resume: | ||
|
@@ -275,11 +270,11 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |
# b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs) | ||
# dataset.mosaic_border = [b - imgsz, -b] # height, width borders | ||
|
||
mloss = torch.zeros(4, device=device) # mean losses | ||
mloss = torch.zeros(3, device=device) # mean losses | ||
if RANK != -1: | ||
train_loader.sampler.set_epoch(epoch) | ||
pbar = enumerate(train_loader) | ||
LOGGER.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size')) | ||
LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size')) | ||
if RANK in [-1, 0]: | ||
pbar = tqdm(pbar, total=nb) # progress bar | ||
optimizer.zero_grad() | ||
|
@@ -327,20 +322,13 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |
ema.update(model) | ||
last_opt_step = ni | ||
|
||
# Log | ||
if RANK in [-1, 0]: | ||
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses | ||
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB) | ||
s = ('%10s' * 2 + '%10.4g' * 6) % ( | ||
f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]) | ||
pbar.set_description(s) | ||
|
||
# Plot | ||
if plots: | ||
if ni < 3: | ||
f = save_dir / f'train_batch{ni}.jpg' # filename | ||
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start() | ||
loggers.on_train_batch_end(ni, model, imgs) | ||
pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % ( | ||
f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])) | ||
loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots) | ||
|
||
# end batch ------------------------------------------------------------------------------------------------ | ||
|
||
|
@@ -371,13 +359,12 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, [email protected], [email protected]] | ||
if fi > best_fitness: | ||
best_fitness = fi | ||
loggers.on_train_val_end(mloss, results, lr, epoch, s, best_fitness, fi) | ||
loggers.on_train_val_end(mloss, results, lr, epoch, best_fitness, fi) | ||
|
||
# Save model | ||
if (not nosave) or (final_epoch and not evolve): # if save | ||
ckpt = {'epoch': epoch, | ||
'best_fitness': best_fitness, | ||
'training_results': results_file.read_text(), | ||
'model': deepcopy(de_parallel(model)).half(), | ||
'ema': deepcopy(ema.ema).half(), | ||
'updates': ema.updates, | ||
|
@@ -395,9 +382,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |
# end training ----------------------------------------------------------------------------------------------------- | ||
if RANK in [-1, 0]: | ||
LOGGER.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n') | ||
if plots: | ||
plot_results(save_dir=save_dir) # save as results.png | ||
|
||
if not evolve: | ||
if is_coco: # COCO dataset | ||
for m in [last, best] if best.exists() else [last]: # speed, mAP tests | ||
|
@@ -411,13 +395,11 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary | |
save_dir=save_dir, | ||
save_json=True, | ||
plots=False) | ||
|
||
# Strip optimizers | ||
for f in last, best: | ||
if f.exists(): | ||
strip_optimizer(f) # strip optimizers | ||
|
||
loggers.on_train_end(last, best) | ||
loggers.on_train_end(last, best, plots) | ||
|
||
torch.cuda.empty_cache() | ||
return results | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,5 @@ | ||
# Plotting utils | ||
|
||
import glob | ||
import os | ||
from copy import copy | ||
from pathlib import Path | ||
|
||
|
@@ -387,63 +385,29 @@ def profile_idetection(start=0, stop=0, labels=(), save_dir=''): | |
plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200) | ||
|
||
|
||
def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_results_overlay() | ||
# Plot training 'results*.txt', overlaying train and val losses | ||
s = ['train', 'train', 'train', 'Precision', '[email protected]', 'val', 'val', 'val', 'Recall', '[email protected]:0.95'] # legends | ||
t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles | ||
for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')): | ||
results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T | ||
n = results.shape[1] # number of rows | ||
x = range(start, min(stop, n) if stop else n) | ||
fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True) | ||
ax = ax.ravel() | ||
for i in range(5): | ||
for j in [i, i + 5]: | ||
y = results[j, x] | ||
ax[i].plot(x, y, marker='.', label=s[j]) | ||
# y_smooth = butter_lowpass_filtfilt(y) | ||
# ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j]) | ||
|
||
ax[i].set_title(t[i]) | ||
ax[i].legend() | ||
ax[i].set_ylabel(f) if i == 0 else None # add filename | ||
fig.savefig(f.replace('.txt', '.png'), dpi=200) | ||
|
||
|
||
def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): | ||
# Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp') | ||
def plot_results(file='', dir=''): | ||
# Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv') | ||
save_dir = Path(file).parent if file else Path(dir) | ||
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True) | ||
ax = ax.ravel() | ||
s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall', | ||
'val Box', 'val Objectness', 'val Classification', '[email protected]', '[email protected]:0.95'] | ||
if bucket: | ||
# files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id] | ||
files = ['results%g.txt' % x for x in id] | ||
c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id) | ||
os.system(c) | ||
else: | ||
files = list(Path(save_dir).glob('results*.txt')) | ||
assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir) | ||
files = list(save_dir.glob('results*.csv')) | ||
assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.' | ||
for fi, f in enumerate(files): | ||
try: | ||
results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T | ||
n = results.shape[1] # number of rows | ||
x = range(start, min(stop, n) if stop else n) | ||
for i in range(10): | ||
y = results[i, x] | ||
if i in [0, 1, 2, 5, 6, 7]: | ||
y[y == 0] = np.nan # don't show zero loss values | ||
# y /= y[0] # normalize | ||
label = labels[fi] if len(labels) else f.stem | ||
ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8) | ||
ax[i].set_title(s[i]) | ||
# if i in [5, 6, 7]: # share train and val loss y axes | ||
data = pd.read_csv(f) | ||
s = [x.strip() for x in data.columns] | ||
x = data.values[:, 0] | ||
for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]): | ||
y = data.values[:, j] | ||
# y[y == 0] = np.nan # don't show zero values | ||
ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8) | ||
ax[i].set_title(s[j], fontsize=12) | ||
# if j in [8, 9, 10]: # share train and val loss y axes | ||
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5]) | ||
except Exception as e: | ||
print('Warning: Plotting error for %s; %s' % (f, e)) | ||
|
||
print(f'Warning: Plotting error for {f}: {e}') | ||
ax[1].legend() | ||
fig.savefig(Path(save_dir) / 'results.png', dpi=200) | ||
fig.savefig(save_dir / 'results.png', dpi=200) | ||
|
||
|
||
def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters