diff --git a/example/gluon/image_classification.py b/example/gluon/image_classification.py index 8481afb50c1a..a67da3534135 100644 --- a/example/gluon/image_classification.py +++ b/example/gluon/image_classification.py @@ -64,6 +64,9 @@ parser.add_argument('--kvstore', type=str, default='device', help='kvstore to use for trainer/module.') parser.add_argument('--log-interval', type=int, default=50, help='Number of batches to wait before logging.') +parser.add_argument('--profile', action='store_true', + help='Option to turn on memory profiling for front-end, '\ + 'and prints out the memory usage by python function at the end.') opt = parser.parse_args() logging.info(opt) @@ -166,7 +169,7 @@ def train(epochs, ctx): net.save_params('image-classifier-%s-%d.params'%(opt.model, epochs)) -if __name__ == '__main__': +def main(): if opt.mode == 'symbolic': data = mx.sym.var('data') out = net(data) @@ -186,3 +189,16 @@ def train(epochs, ctx): if opt.mode == 'hybrid': net.hybridize() train(opt.epochs, context) + +if __name__ == '__main__': + if opt.profile: + import hotshot, hotshot.stats + prof = hotshot.Profile('image-classifier-%s-%s.prof'%(opt.model, opt.mode)) + prof.runcall(main) + prof.close() + stats = hotshot.stats.load('image-classifier-%s-%s.prof'%(opt.model, opt.mode)) + stats.strip_dirs() + stats.sort_stats('cumtime', 'calls') + stats.print_stats() + else: + main()