-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add benchmark.py #35
Add benchmark.py #35
Changes from 3 commits
7894e6e
2dc0e2e
233993e
a2870e4
433585d
e2bcbe3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -306,6 +306,12 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} con | |
CUDA_VISIBLE_DEVICES=4,5,6,7 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config2.py ${WORK_DIR} 4 | ||
``` | ||
|
||
## Benchmark | ||
You can get average training time for an iteration, we only care about the model training, not including the IO time and pre-processing time. | ||
```shell | ||
python tools/benchmark.py ${MMPOSE_CONFIG_FILE} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. script name |
||
``` | ||
|
||
## Tutorials | ||
|
||
Currently, we provide some tutorials for users to [finetune model](tutorials/finetune.md), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import argparse | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rename the file to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. since this script is for inference, let's change the unit back to |
||
import time | ||
|
||
import torch | ||
from mmcv import Config | ||
from mmcv.parallel import MMDataParallel | ||
|
||
from mmpose.core import wrap_fp16_model | ||
from mmpose.datasets import build_dataloader, build_dataset | ||
from mmpose.models import build_posenet | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser( | ||
description='MMPose benchmark a recognizer') | ||
parser.add_argument('config', help='test config file path') | ||
parser.add_argument( | ||
'--log-interval', default=10, help='interval of logging') | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
|
||
cfg = Config.fromfile(args.config) | ||
# set cudnn_benchmark | ||
if cfg.get('cudnn_benchmark', False): | ||
torch.backends.cudnn.benchmark = True | ||
|
||
# build the dataloader | ||
dataset = build_dataset(cfg.data.train) | ||
data_loader = build_dataloader( | ||
dataset, | ||
samples_per_gpu=cfg.data.samples_per_gpu, | ||
workers_per_gpu=cfg.data.workers_per_gpu, | ||
dist=False, | ||
shuffle=False) | ||
|
||
# build the model and load checkpoint | ||
model = build_posenet(cfg.model) | ||
fp16_cfg = cfg.get('fp16', None) | ||
if fp16_cfg is not None: | ||
wrap_fp16_model(model) | ||
model = MMDataParallel(model, device_ids=[0]) | ||
|
||
# the first several iterations may be very slow so skip them | ||
num_warmup = 5 | ||
pure_inf_time = 0 | ||
|
||
# benchmark with total batch and take the average | ||
for i, data in enumerate(data_loader): | ||
|
||
torch.cuda.synchronize() | ||
start_time = time.perf_counter() | ||
|
||
model(return_loss=True, **data) | ||
|
||
torch.cuda.synchronize() | ||
elapsed = time.perf_counter() - start_time | ||
|
||
if i >= num_warmup: | ||
pure_inf_time += elapsed | ||
if (i + 1) % args.log_interval == 0: | ||
its = pure_inf_time / (i + 1 - num_warmup) | ||
print( | ||
f'Done batch [{i + 1:<3}], {its:.2f} s / iter', | ||
flush=True) | ||
print(f'Overall average: {its:.2f} s / iter', flush=True) | ||
print(f'Total time: {pure_inf_time:.2f} s', flush=True) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.