Commit: Tiny fix of mmedit/apis/test.py (#261)
* tiny fix

* Tiny Fix, add limited_gpu.

* Tiny Fix

* Tiny Fix

Co-authored-by: liyinshuo <[email protected]>
Yshuo-Li and liyinshuo authored Apr 21, 2021
1 parent: f24c9ef · commit: d1320da
Showing 3 changed files with 11 additions and 4 deletions.
docs/faq.md (3 additions, 1 deletion)
````diff
@@ -54,5 +54,7 @@ data = dict(
         gt_folder='data/val_set5/Set5_mod12',
         pipeline=test_pipeline,
         scale=scale,
-        filename_tmpl='{}'),
+        filename_tmpl='{}')
+
+empty_cache = True  # empty cache in every iteration.
 ```
````
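For context on what the new option does: `torch.cuda.empty_cache()` releases unoccupied blocks held by PyTorch's caching allocator back to the GPU driver; it does not free tensors that are still referenced. A minimal standalone sketch of the effect (assumes a CUDA-capable machine; the tensor and variable names are illustrative):

```python
import torch

x = torch.empty(1024, 1024, device='cuda')  # reserve some GPU memory
del x                         # blocks return to PyTorch's cache, not the driver
before = torch.cuda.memory_reserved()
torch.cuda.empty_cache()      # hand cached, unoccupied blocks back to the driver
after = torch.cuda.memory_reserved()
print(before, after)          # `after` should be lower than `before`
```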
mmedit/apis/test.py (5 additions, 2 deletions)
```diff
@@ -63,7 +63,8 @@ def multi_gpu_test(model,
                    gpu_collect=False,
                    save_image=False,
                    save_path=None,
-                   iteration=None):
+                   iteration=None,
+                   empty_cache=False):
     """Test model with multiple gpus.
 
     This method tests model with multiple gpus and collects the results
@@ -82,6 +83,7 @@ def multi_gpu_test(model,
         save_path (str): The path to save image. Default: None.
         iteration (int): Iteration number. It is used for the save image name.
             Default: None.
+        empty_cache (bool): empty cache in every iteration. Default: False.
 
     Returns:
         list: The prediction results.
@@ -105,7 +107,8 @@ def multi_gpu_test(model,
                 iteration=iteration,
                 **data)
             results.append(result)
-
+            if empty_cache:
+                torch.cuda.empty_cache()
         if rank == 0:
             # get batch size
             for _, v in data.items():
```
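Stripped of the distributed bookkeeping, the new behavior in `multi_gpu_test` boils down to the pattern below. This is a minimal sketch rather than the actual function; `model`, `data_loader`, and `empty_cache` stand in for the real arguments:

```python
import torch

def test_loop(model, data_loader, empty_cache=False):
    """Sketch of per-iteration cache clearing during testing."""
    results = []
    for data in data_loader:
        with torch.no_grad():  # inference only; no gradients needed
            result = model(test_mode=True, **data)
        results.append(result)
        if empty_cache:
            # Trade a little speed for a lower GPU memory ceiling:
            # release cached allocator blocks after every batch.
            torch.cuda.empty_cache()
    return results
```

The call sits inside the loop on purpose: peak reserved memory stays close to what a single batch needs, at the cost of re-allocating on the next iteration.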
tools/test.py (3 additions, 1 deletion)
```diff
@@ -88,6 +88,7 @@ def main():
     model = build_model(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
 
     args.save_image = args.save_path is not None
+    empty_cache = cfg.get('empty_cache', False)
     if not distributed:
         _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
         model = MMDataParallel(model, device_ids=[0])
@@ -115,7 +116,8 @@ def main():
             args.tmpdir,
             args.gpu_collect,
             save_path=args.save_path,
-            save_image=args.save_image)
+            save_image=args.save_image,
+            empty_cache=empty_cache)
 
     if rank == 0:
         print('')
```
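End to end, a user enables the option by adding `empty_cache = True` at the top level of the test config, exactly as the FAQ snippet above shows; `tools/test.py` then reads it with `cfg.get('empty_cache', False)` and forwards it to `multi_gpu_test`. A small sketch of the consuming side (the config path is hypothetical):

```python
from mmcv import Config

cfg = Config.fromfile('configs/restorers/my_test_config.py')  # hypothetical path
empty_cache = cfg.get('empty_cache', False)  # False when the key is absent
```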
