Skip to content

Commit

Permalink
EMA FP32 assert classification bug fix (#9016)
Browse files Browse the repository at this point in the history
* Return EMA float on classification val

* verbose val fix

* EMA check
  • Loading branch information
glenn-jocher authored Aug 18, 2022
1 parent 529aafd commit 20049be
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 11 deletions.
3 changes: 2 additions & 1 deletion classify/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def run(
if verbose: # all classes
LOGGER.info(f"{'Class':>24}{'Images':>12}{'top1_acc':>12}{'top5_acc':>12}")
LOGGER.info(f"{'all':>24}{targets.shape[0]:>12}{top1:>12.3g}{top5:>12.3g}")
for i, c in enumerate(model.names):
for i, c in model.names.items():
aci = acc[targets == i]
top1i, top5i = aci.mean(0).tolist()
LOGGER.info(f"{c:>24}{aci.shape[0]:>12}{top1i:>12.3g}{top5i:>12.3g}")
Expand All @@ -127,6 +127,7 @@ def run(
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")

model.float() # for training
return top1, top5, loss


Expand Down
2 changes: 1 addition & 1 deletion export.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ def parse_opt():
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
parser.add_argument('--include',
nargs='+',
default=['torchscript', 'onnx'],
default=['torchscript'],
help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')
opt = parser.parse_args()
print_args(vars(opt))
Expand Down
10 changes: 7 additions & 3 deletions models/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import torch
import torch.nn as nn

from models.common import Conv
from utils.downloads import attempt_download


Expand Down Expand Up @@ -79,11 +78,16 @@ def attempt_load(weights, device=None, inplace=True, fuse=True):
for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model

# Model compatibility updates
if not hasattr(ckpt, 'stride'):
ckpt.stride = torch.tensor([32.]) # compatibility update for ResNet etc.
ckpt.stride = torch.tensor([32.])
if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
ckpt.names = dict(enumerate(ckpt.names)) # convert to dict

model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode

# Compatibility updates
# Module compatibility updates
for m in model.modules():
t = type(m)
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
Expand Down
3 changes: 1 addition & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
data_dict = data_dict or check_dataset(data) # check if None
train_path, val_path = data_dict['train'], data_dict['val']
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset

# Model
Expand Down
7 changes: 3 additions & 4 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,6 @@ class ModelEMA:
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
# Create EMA
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
# if next(model.parameters()).device.type != 'cpu':
# self.ema.half() # FP16 EMA
self.updates = updates # number of EMA updates
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
for p in self.ema.parameters():
Expand All @@ -423,9 +421,10 @@ def update(self, model):

msd = de_parallel(model).state_dict() # model state_dict
for k, v in self.ema.state_dict().items():
if v.dtype.is_floating_point:
if v.dtype.is_floating_point: # true for FP16 and FP32
v *= d
v += (1 - d) * msd[k].detach()
v += (1 - d) * msd[k]
assert v.dtype == msd[k].dtype == torch.float32, f'EMA {v.dtype} and model {msd[k]} must be updated in FP32'

def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
# Update EMA attributes
Expand Down

0 comments on commit 20049be

Please sign in to comment.