[Feature] Support model complexity computation #779
Conversation
Lint failed.
Success with a one-stage detector, but failure with a two-stage detector in mmdet 3.x. @ZwwWayne @RangiLyu

get_flops.py:

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import numpy as np
import torch
from mmengine.config import Config, DictAction

from mmdet.registry import MODELS
from mmdet.utils import register_all_modules

try:
    from mmengine.analysis import get_model_complexity_info
except ImportError:
    raise ImportError('Please upgrade mmcv to >0.6.2')


def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[1280, 800],
        help='input image size')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--size-divisor',
        type=int,
        default=32,
        help='Pad the input image, the minimum size that is divisible '
        'by size_divisor, -1 means do not pad the image.')
    args = parser.parse_args()
    return args


def main():
    register_all_modules()
    args = parse_args()

    if len(args.shape) == 1:
        h = w = args.shape[0]
    elif len(args.shape) == 2:
        h, w = args.shape
    else:
        raise ValueError('invalid input shape')
    ori_shape = (3, h, w)
    divisor = args.size_divisor
    if divisor > 0:
        h = int(np.ceil(h / divisor)) * divisor
        w = int(np.ceil(w / divisor)) * divisor

    input_shape = (3, h, w)

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    model = MODELS.build(cfg.model)
    # if torch.cuda.is_available():
    #     model.cuda()
    model.eval()

    flops, activations, params, complexity_table, complexity_str = get_model_complexity_info(
        model, input_shape, show_table=True, show_str=True)

    split_line = '=' * 30
    if divisor > 0 and input_shape != ori_shape:
        print(f'{split_line}\nUse size divisor set input shape '
              f'from {ori_shape} to {input_shape}\n')
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n{split_line}')
    print(activations)
    print(complexity_table)
    # print(complexity_str)
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')


if __name__ == '__main__':
    main()
```

```
(mmroate) ➜ mmdetection git:(dev-3.x) ✗ python tools/analysis_tools/get_flops.py configs/yolo/yolov3_d53_8xb8-320-273e_coco.py
12/12 13:20:29 - mmengine - WARNING - Unsupported operator aten::leaky_relu_ encountered 72 time(s)
12/12 13:20:29 - mmengine - WARNING - Unsupported operator aten::add encountered 23 time(s)
12/12 13:20:29 - mmengine - WARNING - The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
bbox_head.loss_cls, bbox_head.loss_conf, bbox_head.loss_wh, bbox_head.loss_xy, data_preprocessor
12/12 13:20:30 - mmengine - WARNING - Unsupported operator aten::batch_norm encountered 72 time(s)
12/12 13:20:30 - mmengine - WARNING - Unsupported operator aten::leaky_relu_ encountered 72 time(s)
12/12 13:20:30 - mmengine - WARNING - Unsupported operator aten::add encountered 23 time(s)
12/12 13:20:30 - mmengine - WARNING - Unsupported operator aten::upsample_nearest2d encountered 2 time(s)
12/12 13:20:30 - mmengine - WARNING - The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
bbox_head.loss_cls, bbox_head.loss_conf, bbox_head.loss_wh, bbox_head.loss_xy, data_preprocessor
==============================
Input shape: (3, 1280, 800)
Flops: 0.195T
Params: 61.949M
==============================
0.232G
| module | #parameters or shape | #flops | #activations |
|:---------------------------------|:-----------------------|:-----------|:---------------|
| model | 61.949M | 0.195T | 0.232G |
| backbone | 40.585M | 0.145T | 0.194G |
| backbone.conv1 | 0.928K | 0.95G | 32.768M |
| backbone.conv1.conv | 0.864K | 0.885G | 32.768M |
| backbone.conv1.bn | 64 | 65.536M | 0 |
| backbone.conv_res_block1 | 39.232K | 10.043G | 40.96M |
| backbone.conv_res_block1.conv | 18.56K | 4.751G | 16.384M |
| backbone.conv_res_block1.res0 | 20.672K | 5.292G | 24.576M |
| backbone.conv_res_block2 | 0.239M | 15.27G | 32.768M |
| backbone.conv_res_block2.conv | 73.984K | 4.735G | 8.192M |
| backbone.conv_res_block2.res0 | 82.304K | 5.267G | 12.288M |
| backbone.conv_res_block2.res1 | 82.304K | 5.267G | 12.288M |
| backbone.conv_res_block3 | 2.923M | 46.768G | 53.248M |
| backbone.conv_res_block3.conv | 0.295M | 4.727G | 4.096M |
| backbone.conv_res_block3.res0 | 0.328M | 5.255G | 6.144M |
| backbone.conv_res_block3.res1 | 0.328M | 5.255G | 6.144M |
| backbone.conv_res_block3.res2 | 0.328M | 5.255G | 6.144M |
| backbone.conv_res_block3.res3 | 0.328M | 5.255G | 6.144M |
| backbone.conv_res_block3.res4 | 0.328M | 5.255G | 6.144M |
| backbone.conv_res_block3.res5 | 0.328M | 5.255G | 6.144M |
| backbone.conv_res_block3.res6 | 0.328M | 5.255G | 6.144M |
| backbone.conv_res_block3.res7 | 0.328M | 5.255G | 6.144M |
| backbone.conv_res_block4 | 11.679M | 46.715G | 26.624M |
| backbone.conv_res_block4.conv | 1.181M | 4.723G | 2.048M |
| backbone.conv_res_block4.res0 | 1.312M | 5.249G | 3.072M |
| backbone.conv_res_block4.res1 | 1.312M | 5.249G | 3.072M |
| backbone.conv_res_block4.res2 | 1.312M | 5.249G | 3.072M |
| backbone.conv_res_block4.res3 | 1.312M | 5.249G | 3.072M |
| backbone.conv_res_block4.res4 | 1.312M | 5.249G | 3.072M |
| backbone.conv_res_block4.res5 | 1.312M | 5.249G | 3.072M |
| backbone.conv_res_block4.res6 | 1.312M | 5.249G | 3.072M |
| backbone.conv_res_block4.res7 | 1.312M | 5.249G | 3.072M |
| backbone.conv_res_block5 | 25.704M | 25.704G | 7.168M |
| backbone.conv_res_block5.conv | 4.721M | 4.721G | 1.024M |
| backbone.conv_res_block5.res0 | 5.246M | 5.246G | 1.536M |
| backbone.conv_res_block5.res1 | 5.246M | 5.246G | 1.536M |
| backbone.conv_res_block5.res2 | 5.246M | 5.246G | 1.536M |
| backbone.conv_res_block5.res3 | 5.246M | 5.246G | 1.536M |
| neck | 14.71M | 33.871G | 25.856M |
| neck.detect1 | 11.017M | 11.017G | 3.584M |
| neck.detect1.conv1 | 0.525M | 0.525G | 0.512M |
| neck.detect1.conv2 | 4.721M | 4.721G | 1.024M |
| neck.detect1.conv3 | 0.525M | 0.525G | 0.512M |
| neck.detect1.conv4 | 4.721M | 4.721G | 1.024M |
| neck.detect1.conv5 | 0.525M | 0.525G | 0.512M |
| neck.conv1 | 0.132M | 0.132G | 0.256M |
| neck.conv1.conv | 0.131M | 0.131G | 0.256M |
| neck.conv1.bn | 0.512K | 0.512M | 0 |
| neck.detect2 | 2.822M | 11.287G | 7.168M |
| neck.detect2.conv1 | 0.197M | 0.788G | 1.024M |
| neck.detect2.conv2 | 1.181M | 4.723G | 2.048M |
| neck.detect2.conv3 | 0.132M | 0.526G | 1.024M |
| neck.detect2.conv4 | 1.181M | 4.723G | 2.048M |
| neck.detect2.conv5 | 0.132M | 0.526G | 1.024M |
| neck.conv2 | 33.024K | 0.132G | 0.512M |
| neck.conv2.conv | 32.768K | 0.131G | 0.512M |
| neck.conv2.bn | 0.256K | 1.024M | 0 |
| neck.detect3 | 0.706M | 11.301G | 14.336M |
| neck.detect3.conv1 | 49.408K | 0.791G | 2.048M |
| neck.detect3.conv2 | 0.295M | 4.727G | 4.096M |
| neck.detect3.conv3 | 33.024K | 0.528G | 2.048M |
| neck.detect3.conv4 | 0.295M | 4.727G | 4.096M |
| neck.detect3.conv5 | 33.024K | 0.528G | 2.048M |
| bbox_head | 6.654M | 15.998G | 12.523M |
| bbox_head.convs_bridge | 6.197M | 14.17G | 7.168M |
| bbox_head.convs_bridge.0 | 4.721M | 4.721G | 1.024M |
| bbox_head.convs_bridge.1 | 1.181M | 4.723G | 2.048M |
| bbox_head.convs_bridge.2 | 0.295M | 4.727G | 4.096M |
| bbox_head.convs_pred | 0.458M | 1.828G | 5.355M |
| bbox_head.convs_pred.0 | 0.261M | 0.261G | 0.255M |
| bbox_head.convs_pred.1 | 0.131M | 0.522G | 1.02M |
| bbox_head.convs_pred.2 | 65.535K | 1.044G | 4.08M |
!!!Please be cautious if you use the results in papers. You may need to check if all ops are supported and verify that the flops computation is correct.
(mmroate) ➜ mmdetection git:(dev-3.x) ✗ python tools/analysis_tools/get_flops.py configs/faster_rcnn/faster-rcnn_r101_fpn_1x_coco.py
Traceback (most recent call last):
File "tools/analysis_tools/get_flops.py", line 92, in <module>
main()
File "tools/analysis_tools/get_flops.py", line 73, in main
flops, activations, params, complexity_table, complexity_str = get_model_complexity_info(model, input_shape, show_table=True, show_str=True)
File "/home/ubuntu/mmroate-1.x/mmengine/mmengine/analysis/print_helper.py", line 668, in get_model_complexity_info
flops = _format_size(flop_handler.total())
File "/home/ubuntu/mmroate-1.x/mmengine/mmengine/analysis/jit_analysis.py", line 259, in total
stats = self._analyze()
File "/home/ubuntu/mmroate-1.x/mmengine/mmengine/analysis/jit_analysis.py", line 550, in _analyze
graph = _get_scoped_trace_graph(self._model, self._inputs,
File "/home/ubuntu/mmroate-1.x/mmengine/mmengine/analysis/jit_analysis.py", line 189, in _get_scoped_trace_graph
graph, _ = _get_trace_graph(module, inputs)
File "/home/ubuntu/miniconda3/envs/mmroate/lib/python3.8/site-packages/torch/jit/_trace.py", line 1166, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/home/ubuntu/miniconda3/envs/mmroate/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/ubuntu/miniconda3/envs/mmroate/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
graph, out = torch._C._create_graph_by_tracing(
File "/home/ubuntu/miniconda3/envs/mmroate/lib/python3.8/site-packages/torch/jit/_trace.py", line 118, in wrapper
outs.append(self.inner(*trace_inputs))
File "/home/ubuntu/miniconda3/envs/mmroate/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1120, in _call_impl
result = forward_call(*input, **kwargs)
File "/home/ubuntu/miniconda3/envs/mmroate/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1090, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/ubuntu/mmroate-1.x/mmdetection/mmdet/models/detectors/base.py", line 96, in forward
return self._forward(inputs, data_samples)
File "/home/ubuntu/mmroate-1.x/mmdetection/mmdet/models/detectors/two_stage.py", line 131, in _forward
rpn_results_list = self.rpn_head.predict(
File "/home/ubuntu/mmroate-1.x/mmdetection/mmdet/models/dense_heads/base_dense_head.py", line 191, in predict
batch_img_metas = [
TypeError: 'NoneType' object is not iterable
```
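The two-stage failure comes from `TwoStageDetector._forward`, which calls `self.rpn_head.predict(x, batch_data_samples)`; with only a dummy tensor, `data_samples` is `None`, so building `batch_img_metas` fails. Below is a minimal workaround sketch, not the PR's official recipe: it assumes mmdet 3.x's `DetDataSample` metainfo keys and the dict-style return value and `inputs` argument shown in the later scripts in this thread.

```python
# Sketch only: give the two-stage detector a dummy DetDataSample so that
# rpn_head.predict() has image metas to iterate over during tracing.
from functools import partial

import torch
from mmengine.analysis import get_model_complexity_info
from mmengine.config import Config

from mmdet.registry import MODELS
from mmdet.structures import DetDataSample
from mmdet.utils import register_all_modules

register_all_modules()
cfg = Config.fromfile('configs/faster_rcnn/faster-rcnn_r101_fpn_1x_coco.py')
model = MODELS.build(cfg.model)
model.eval()

h, w = 1280, 800  # padded input resolution used for the trace
data_sample = DetDataSample()
data_sample.set_metainfo({
    'img_shape': (h, w),
    'ori_shape': (h, w),
    'pad_shape': (h, w),
    'scale_factor': (1.0, 1.0),
    'batch_input_shape': (h, w),
})

# Freeze data_samples into forward so the tracer only has to supply the tensor.
model.forward = partial(model.forward, data_samples=[data_sample])
results = get_model_complexity_info(model, inputs=torch.randn(1, 3, h, w))
print(results['flops_str'], results['params_str'])
```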
a591787 to 14524f0 (Compare)
Usage in mmcls
```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

from mmengine import Config
from mmengine.analysis import get_model_complexity_info

from mmcls.models import build_classifier


def parse_args():
    parser = argparse.ArgumentParser(description='Get model flops and params')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[224, 224],
        help='input image size')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = Config.fromfile(args.config)
    model = build_classifier(cfg.model)
    model.eval()

    if hasattr(model, 'extract_feat'):
        model.forward = model.extract_feat
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported for {}'.format(
                model.__class__.__name__))

    analysis_results = get_model_complexity_info(model, input_shape)
    flops = analysis_results['flops_str']
    params = analysis_results['params_str']
    activations = analysis_results['activations_str']
    out_table = analysis_results['out_table']
    out_arch = analysis_results['out_arch']

    print(out_table)
    print(out_arch)
    split_line = '=' * 30
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n'
          f'Activation: {activations}\n{split_line}')
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')


if __name__ == '__main__':
    main()
```
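For reference, here is a minimal standalone sketch of the result dict that `get_model_complexity_info` returns, using a plain `torch.nn` model instead of an mmcls classifier; the dict keys are the ones read in the script above.

```python
# Standalone sketch: inspect the result dict on a toy torch model.
import torch.nn as nn
from mmengine.analysis import get_model_complexity_info

toy_model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
)
toy_model.eval()

results = get_model_complexity_info(toy_model, input_shape=(3, 224, 224))
print(results['flops_str'])        # human-readable FLOPs string
print(results['params_str'])       # human-readable parameter count
print(results['activations_str'])  # human-readable activation count
print(results['out_table'])        # per-module breakdown table
print(results['out_arch'])         # model architecture annotated with stats
```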
Usage in mmdet
```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from functools import partial

import numpy as np
import torch
from mmengine.analysis import get_model_complexity_info
from mmengine.config import Config, DictAction
from mmengine.logging import MMLogger
from mmengine.runner import Runner

from mmdet.registry import MODELS
from mmdet.utils import register_all_modules


def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[1280, 800],
        help='input image size')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--size-divisor',
        type=int,
        default=32,
        help='Pad the input image, the minimum size that is divisible '
        'by size_divisor, -1 means do not pad the image.')
    args = parser.parse_args()
    return args


def main():
    register_all_modules()
    args = parse_args()

    if len(args.shape) == 1:
        h = w = args.shape[0]
    elif len(args.shape) == 2:
        h, w = args.shape
    else:
        raise ValueError('invalid input shape')
    ori_shape = (3, h, w)
    divisor = args.size_divisor
    if divisor > 0:
        h = int(np.ceil(h / divisor)) * divisor
        w = int(np.ceil(w / divisor)) * divisor

    input_shape = (3, h, w)

    try:
        cfg = Config.fromfile(args.config)
        if args.cfg_options is not None:
            cfg.merge_from_dict(args.cfg_options)
        model = MODELS.build(cfg.model)
        if torch.cuda.is_available():
            model.cuda()
        model.eval()
        analysis_results = get_model_complexity_info(model, input_shape)
        flops = analysis_results['flops_str']
        activations = analysis_results['activations_str']
        params = analysis_results['params_str']
    except Exception:
        logger = MMLogger.get_instance(name='MMLogger')
        logger.warning('Direct get flops failed, try to get flops with data')
        cfg = Config.fromfile(args.config)
        data_loader = Runner.build_dataloader(cfg.val_dataloader)
        data_batch = next(iter(data_loader))
        model = MODELS.build(cfg.model)
        _forward = model.forward
        data = model.data_preprocessor(data_batch)
        model.forward = partial(_forward, data_samples=data['data_samples'])
        analysis_results = get_model_complexity_info(
            model, input_shape, data['inputs'])
        flops = analysis_results['flops_str']
        activations = analysis_results['activations_str']
        params = analysis_results['params_str']

    print(analysis_results['out_table'])
    print(analysis_results['out_arch'])
    split_line = '=' * 30
    if divisor > 0 and input_shape != ori_shape:
        print(f'{split_line}\nUse size divisor set input shape '
              f'from {ori_shape} to {input_shape}\n')
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n{split_line}')
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')


if __name__ == '__main__':
    main()
```
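A note on the fallback design in this script: it first attempts a direct trace with a synthetic tensor of shape `input_shape`; if that raises (as in the two-stage case above, where `_forward` needs `data_samples`), it rebuilds the model, pulls one real batch from `cfg.val_dataloader`, runs it through `data_preprocessor`, binds the resulting `data_samples` into `forward` with `functools.partial`, and then calls `get_model_complexity_info` again with the preprocessed `inputs`.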
Codecov Report
Additional details and impacted files

```
@@           Coverage Diff            @@
##             main     #779   +/-   ##
=======================================
  Coverage        ?   76.82%
=======================================
  Files           ?      138
  Lines           ?    10791
  Branches        ?     2154
=======================================
  Hits            ?     8290
  Misses          ?     2143
  Partials        ?      358
```
The reason why the unit test did not fail was that …
Implementation of model complexity analysis.
Usage is similar to `mmcv.cnn.get_model_complexity_info`.
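For readers coming from mmcv, a rough migration sketch follows; the mmcv call is shown only as a commented-out comparison, the dict keys follow the scripts above, and `torchvision.models.resnet18` is just a convenient stand-in model.

```python
# Rough migration sketch: mmcv returns a (flops, params) tuple of strings,
# while mmengine.analysis returns a single dict of results.
import torchvision
from mmengine.analysis import get_model_complexity_info

model = torchvision.models.resnet18().eval()

# Old mmcv-style call (kept as a comment for comparison):
# from mmcv.cnn import get_model_complexity_info
# flops, params = get_model_complexity_info(model, (3, 224, 224))

# New mmengine-style call:
results = get_model_complexity_info(model, input_shape=(3, 224, 224))
print(f"Flops: {results['flops_str']}  Params: {results['params_str']}")
```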