diff --git a/maskrcnn_benchmark/utils/checkpoint.py b/maskrcnn_benchmark/utils/checkpoint.py index dc403f5db..2af2565ed 100644 --- a/maskrcnn_benchmark/utils/checkpoint.py +++ b/maskrcnn_benchmark/utils/checkpoint.py @@ -49,8 +49,8 @@ def save(self, name, **kwargs): torch.save(data, save_file) self.tag_last_checkpoint(save_file) - def load(self, f=None): - if self.has_checkpoint(): + def load(self, f=None, use_latest=True): + if self.has_checkpoint() and use_latest: # override argument with existing checkpoint f = self.get_checkpoint_file() if not f: diff --git a/tools/test_net.py b/tools/test_net.py index c666a4655..7a2d3f7b9 100644 --- a/tools/test_net.py +++ b/tools/test_net.py @@ -33,6 +33,11 @@ def main(): help="path to config file", ) parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument( + "--ckpt", + help="The path to the checkpoint for test, default is the latest checkpoint.", + default=None, + ) parser.add_argument( "opts", help="Modify config options using the command-line", @@ -73,7 +78,7 @@ def main(): output_dir = cfg.OUTPUT_DIR checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir) - _ = checkpointer.load(cfg.MODEL.WEIGHT) + _ = checkpointer.load(args.ckpt, use_latest=args.ckpt is None) iou_types = ("bbox",) if cfg.MODEL.MASK_ON: