From 92c81dcc8ca8636dd1bb203330236a0f0fadb9ce Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 13 Aug 2022 16:38:11 +0200 Subject: [PATCH] GFLOPs computation fix for classification models (#8954) * GFLOPs computation fix for classification models Improved robustness in reading input channel count * Update torch_utils.py * Update torch_utils.py --- utils/torch_utils.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 1097458ae45..beb81442912 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -199,12 +199,11 @@ def sparsity(model): def prune(model, amount=0.3): # Prune model to requested global sparsity import torch.nn.utils.prune as prune - print('Pruning model... ', end='') for name, m in model.named_modules(): if isinstance(m, nn.Conv2d): prune.l1_unstructured(m, name='weight', amount=amount) # prune prune.remove(m, 'weight') # make permanent - print(' %.3g global sparsity' % sparsity(model)) + LOGGER.info(f'Model pruned to {sparsity(model):.3g} global sparsity') def fuse_conv_and_bn(conv, bn): @@ -230,7 +229,7 @@ def fuse_conv_and_bn(conv, bn): return fusedconv -def model_info(model, verbose=False, img_size=640): +def model_info(model, verbose=False, imgsz=640): # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320] n_p = sum(x.numel() for x in model.parameters()) # number parameters n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients @@ -242,12 +241,12 @@ def model_info(model, verbose=False, img_size=640): (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std())) try: # FLOPs - from thop import profile - stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 - img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input - flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs - img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float - fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPs + p = next(model.parameters()) + stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride + im = torch.zeros((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format + flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs + imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float + fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs' # 640x640 GFLOPs except Exception: fs = ''