diff --git a/tools/train_net.py b/tools/train_net.py index 3f37d55f5..e4f95f015 100644 --- a/tools/train_net.py +++ b/tools/train_net.py @@ -76,7 +76,7 @@ def train(cfg, local_rank, distributed): return model -def test(cfg, model, distributed): +def run_test(cfg, model, distributed): if distributed: model = model.module torch.cuda.empty_cache() # TODO check if it helps @@ -167,7 +167,7 @@ def main(): model = train(cfg, args.local_rank, args.distributed) if not args.skip_test: - test(cfg, model, args.distributed) + run_test(cfg, model, args.distributed) if __name__ == "__main__":