Skip to content

Commit

Permalink
Generalized regression criterion renaming (#1120)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher authored Oct 11, 2020
1 parent 10c85bf commit 0ada058
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 22 deletions.
2 changes: 1 addition & 1 deletion data/hyp.finetune.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ weight_decay: 0.00036
warmup_epochs: 2.0
warmup_momentum: 0.5
warmup_bias_lr: 0.05
giou: 0.0296
box: 0.0296
cls: 0.243
cls_pw: 0.631
obj: 0.301
Expand Down
2 changes: 1 addition & 1 deletion data/hyp.scratch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ weight_decay: 0.0005 # optimizer weight decay 5e-4
warmup_epochs: 3.0 # warmup epochs (fractions ok)
warmup_momentum: 0.8 # warmup initial momentum
warmup_bias_lr: 0.1 # warmup initial bias lr
giou: 0.05 # box loss gain
box: 0.05 # box loss gain
cls: 0.5 # cls loss gain
cls_pw: 1.0 # cls BCELoss positive_weight
obj: 1.0 # obj loss gain (scale with pixels)
Expand Down
2 changes: 1 addition & 1 deletion sotabench.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test(data,

# Compute loss
if training: # if model has loss hyperparameters
loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # GIoU, obj, cls
loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # box, obj, cls

# Run NMS
t = time_synchronized()
Expand Down
2 changes: 1 addition & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test(data,

# Compute loss
if training: # if model has loss hyperparameters
loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # GIoU, obj, cls
loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # box, obj, cls

# Run NMS
t = time_synchronized()
Expand Down
18 changes: 9 additions & 9 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def train(hyp, opt, device, tb_writer=None):
hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
model.nc = nc # attach number of classes to model
model.hyp = hyp # attach hyperparameters to model
model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
model.names = names

Expand All @@ -204,7 +204,7 @@ def train(hyp, opt, device, tb_writer=None):
nw = max(round(hyp['warmup_epochs'] * nb), 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
# nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, [email protected], val_loss(box, obj, cls)
scheduler.last_epoch = start_epoch - 1 # do not move
scaler = amp.GradScaler(enabled=cuda)
logger.info('Image sizes %g train, %g test\nUsing %g dataloader workers\nLogging results to %s\n'
Expand Down Expand Up @@ -234,7 +234,7 @@ def train(hyp, opt, device, tb_writer=None):
if rank != -1:
dataloader.sampler.set_epoch(epoch)
pbar = enumerate(dataloader)
logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'targets', 'img_size'))
if rank in [-1, 0]:
pbar = tqdm(pbar, total=nb) # progress bar
optimizer.zero_grad()
Expand All @@ -245,7 +245,7 @@ def train(hyp, opt, device, tb_writer=None):
# Warmup
if ni <= nw:
xi = [0, nw] # x interp
# model.gr = np.interp(ni, xi, [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou)
# model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
for j, x in enumerate(optimizer.param_groups):
# bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
Expand Down Expand Up @@ -319,21 +319,21 @@ def train(hyp, opt, device, tb_writer=None):

# Write
with open(results_file, 'a') as f:
f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP@.5, [email protected], val_loss(box, obj, cls)
if len(opt.name) and opt.bucket:
os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))

# Tensorboard
if tb_writer:
tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss', # train loss
tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
'val/giou_loss', 'val/obj_loss', 'val/cls_loss', # val loss
'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
'x/lr0', 'x/lr1', 'x/lr2'] # params
for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
tb_writer.add_scalar(tag, x, epoch)

# Update best mAP
fi = fitness(np.array(results).reshape(1, -1)) # fitness_i = weighted combination of [P, R, mAP, F1]
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, [email protected]]
if fi > best_fitness:
best_fitness = fi

Expand Down Expand Up @@ -463,7 +463,7 @@ def train(hyp, opt, device, tb_writer=None):
'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
'giou': (1, 0.02, 0.2), # GIoU loss gain
'box': (1, 0.02, 0.2), # box loss gain
'cls': (1, 0.2, 4.0), # cls loss gain
'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
Expand Down
18 changes: 9 additions & 9 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,11 +509,11 @@ def compute_loss(p, targets, model): # predictions, targets, model
pxy = ps[:, :2].sigmoid() * 2. - 0.5
pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
pbox = torch.cat((pxy, pwh), 1).to(device) # predicted box
giou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # giou(prediction, target)
lbox += (1.0 - giou).mean() # giou loss
iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target)
lbox += (1.0 - iou).mean() # iou loss

# Objectness
tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().clamp(0).type(tobj.dtype) # giou ratio
tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio

# Classification
if model.nc > 1: # cls loss (only if multiple classes)
Expand All @@ -528,7 +528,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss

s = 3 / np # output count scaling
lbox *= h['giou'] * s
lbox *= h['box'] * s
lobj *= h['obj'] * s * (1.4 if np == 4 else 1.)
lcls *= h['cls'] * s
bs = tobj.shape[0] # batch size
Expand Down Expand Up @@ -1234,7 +1234,7 @@ def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.general im
def plot_results_overlay(start=0, stop=0): # from utils.general import *; plot_results_overlay()
# Plot training 'results*.txt', overlaying train and val losses
s = ['train', 'train', 'train', 'Precision', '[email protected]', 'val', 'val', 'val', 'Recall', '[email protected]:0.95'] # legends
t = ['GIoU', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles
t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles
for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')):
results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
n = results.shape[1] # number of rows
Expand All @@ -1254,13 +1254,13 @@ def plot_results_overlay(start=0, stop=0): # from utils.general import *; plot_
fig.savefig(f.replace('.txt', '.png'), dpi=200)


def plot_results(start=0, stop=0, bucket='', id=(), labels=(),
save_dir=''): # from utils.general import *; plot_results()
def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
# from utils.general import *; plot_results()
# Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov5#reproduce-our-training
fig, ax = plt.subplots(2, 5, figsize=(12, 6))
ax = ax.ravel()
s = ['GIoU', 'Objectness', 'Classification', 'Precision', 'Recall',
'val GIoU', 'val Objectness', 'val Classification', '[email protected]', '[email protected]:0.95']
s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall',
'val Box', 'val Objectness', 'val Classification', '[email protected]', '[email protected]:0.95']
if bucket:
# os.system('rm -rf storage.googleapis.com')
# files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
Expand Down

0 comments on commit 0ada058

Please sign in to comment.