[MMSIG-80] Update and refine get_flops.py (open-mmlab#2237)
xin-li-67 authored and Tau-J committed Apr 25, 2023
1 parent 1e239b7 commit 912d458
Showing 1 changed file with 71 additions and 44 deletions.
tools/analysis_tools/get_flops.py (115 changes: 71 additions & 44 deletions)
@@ -1,25 +1,26 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
-from functools import partial
 
+import numpy as np
 import torch
 from mmengine.config import DictAction
+from mmengine.logging import MMLogger
 
 from mmpose.apis.inference import init_model
 
 try:
-    from mmcv.cnn import get_model_complexity_info
+    from mmengine.analysis import get_model_complexity_info
+    from mmengine.analysis.print_helper import _format_size
 except ImportError:
-    raise ImportError('Please upgrade mmcv to >0.6.2')
+    raise ImportError('Please upgrade mmengine >= 0.6.0')
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(description='Train a recognizer')
+    parser = argparse.ArgumentParser(
+        description='Get complexity information from a model config')
     parser.add_argument('config', help='train config file path')
     parser.add_argument(
-        '--device',
-        default='cuda:0',
-        help='Device used for model initialization')
+        '--device', default='cpu', help='Device used for model initialization')
     parser.add_argument(
         '--cfg-options',
         nargs='+',
@@ -29,28 +30,24 @@ def parse_args():
         'in xxx=yyy format will be merged into config file. For example, '
         "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
     parser.add_argument(
-        '--shape',
+        '--input-shape',
         type=int,
         nargs='+',
         default=[256, 192],
         help='input image size')
     parser.add_argument(
-        '--input-constructor',
-        '-c',
-        type=str,
-        choices=['none', 'batch'],
-        default='none',
-        help='If specified, it takes a callable method that generates '
-        'input. Otherwise, it will generate a random tensor with '
-        'input shape to calculate FLOPs.')
-    parser.add_argument(
-        '--batch-size', '-b', type=int, default=1, help='input batch size')
+        '--batch-size',
+        '-b',
+        type=int,
+        default=1,
+        help='Input batch size. If specified and greater than 1, it takes a '
+        'callable method that generates a batch input. Otherwise, it will '
+        'generate a random tensor with input shape to calculate FLOPs.')
     parser.add_argument(
-        '--not-print-per-layer-stat',
-        '-n',
+        '--show-arch-info',
+        '-s',
         action='store_true',
-        help='Whether to print complexity information'
-        'for each layer in a model')
+        help='Whether to show model arch information')
     args = parser.parse_args()
     return args
 
@@ -59,7 +56,7 @@ def batch_constructor(flops_model, batch_size, input_shape):
     """Generate a batch of tensors to the model."""
     batch = {}
 
-    inputs = torch.ones(()).new_empty(
+    inputs = torch.randn(batch_size, *input_shape).new_empty(
         (batch_size, *input_shape),
         dtype=next(flops_model.parameters()).dtype,
         device=next(flops_model.parameters()).device)
@@ -68,44 +65,74 @@ def batch_constructor(flops_model, batch_size, input_shape):
     return batch
 
 
-def main():
-
-    args = parse_args()
-
-    if len(args.shape) == 1:
-        input_shape = (3, args.shape[0], args.shape[0])
-    elif len(args.shape) == 2:
-        input_shape = (3, ) + tuple(args.shape)
-    else:
-        raise ValueError('invalid input shape')
-
+def inference(args, input_shape, logger):
     model = init_model(
         args.config,
         checkpoint=None,
         device=args.device,
         cfg_options=args.cfg_options)
 
-    if args.input_constructor == 'batch':
-        input_constructor = partial(batch_constructor, model, args.batch_size)
-    else:
-        input_constructor = None
-
     if hasattr(model, '_forward'):
         model.forward = model._forward
     else:
         raise NotImplementedError(
             'FLOPs counter is currently not currently supported with {}'.
             format(model.__class__.__name__))
 
-    flops, params = get_model_complexity_info(
-        model,
-        input_shape,
-        input_constructor=input_constructor,
-        print_per_layer_stat=(not args.not_print_per_layer_stat))
+    if args.batch_size > 1:
+        outputs = {}
+        avg_flops = []
+        logger.info('Running get_flops with batch size specified as {}'.format(
+            args.batch_size))
+        batch = batch_constructor(model, args.batch_size, input_shape)
+        for i in range(args.batch_size):
+            result = get_model_complexity_info(
+                model,
+                input_shape,
+                inputs=batch['inputs'],
+                show_table=True,
+                show_arch=args.show_arch_info)
+            avg_flops.append(result['flops'])
+        mean_flops = _format_size(int(np.average(avg_flops)))
+        outputs['flops_str'] = mean_flops
+        outputs['params_str'] = result['params_str']
+        outputs['out_table'] = result['out_table']
+        outputs['out_arch'] = result['out_arch']
+    else:
+        outputs = get_model_complexity_info(
+            model,
+            input_shape,
+            inputs=None,
+            show_table=True,
+            show_arch=args.show_arch_info)
+    return outputs
+
+
+def main():
+    args = parse_args()
+    logger = MMLogger.get_instance(name='MMLogger')
+
+    if len(args.input_shape) == 1:
+        input_shape = (3, args.input_shape[0], args.input_shape[0])
+    elif len(args.input_shape) == 2:
+        input_shape = (3, ) + tuple(args.input_shape)
+    else:
+        raise ValueError('invalid input shape')
+
+    if args.device == 'cuda:0':
+        assert torch.cuda.is_available(
+        ), 'No valid cuda device detected, please double check...'
+
+    outputs = inference(args, input_shape, logger)
+    flops = outputs['flops_str']
+    params = outputs['params_str']
     split_line = '=' * 30
     input_shape = (args.batch_size, ) + input_shape
     print(f'{split_line}\nInput shape: {input_shape}\n'
           f'Flops: {flops}\nParams: {params}\n{split_line}')
+    print(outputs['out_table'])
+    if args.show_arch_info:
+        print(outputs['out_arch'])
     print('!!!Please be cautious if you use the results in papers. '
           'You may need to check if all ops are supported and verify that the '
           'flops computation is correct.')
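For reference, the refined script relies on mmengine.analysis.get_model_complexity_info, which returns a dict of complexity results ('flops_str', 'params_str', 'out_table', 'out_arch', ...) rather than the (flops, params) tuple produced by the old mmcv.cnn helper. Below is a minimal sketch of that call on its own, assuming mmengine >= 0.6.0 is installed; the Conv2d module is a hypothetical stand-in for the pose model that init_model would build.

# Minimal sketch (not part of the commit): query complexity info via mmengine,
# mirroring the call made in the updated get_flops.py.
import torch
from mmengine.analysis import get_model_complexity_info

# Hypothetical stand-in for a model returned by init_model(config, ...).
model = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)

results = get_model_complexity_info(
    model,
    (3, 256, 192),   # input_shape, matching the script's default --input-shape
    inputs=None,     # let mmengine build a random input tensor of this shape
    show_table=True,
    show_arch=False)
print(results['flops_str'], results['params_str'])
print(results['out_table'])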
