diff --git a/GANDLF/entrypoints/optimize_model.py b/GANDLF/entrypoints/optimize_model.py index 7cf5e9a7f..aa1a9a1e0 100644 --- a/GANDLF/entrypoints/optimize_model.py +++ b/GANDLF/entrypoints/optimize_model.py @@ -11,8 +11,12 @@ from GANDLF.entrypoints import append_copyright_to_help -def _optimize_model(model: str, config: Optional[str]): - if post_training_model_optimization(model_path=model, config_path=config): +def _optimize_model( + model: str, config: Optional[str], output_path: Optional[str] = None +): + if post_training_model_optimization( + model_path=model, config_path=config, output_path=output_path + ): print("Post-training model optimization successful.") else: print("Post-training model optimization failed.") @@ -26,6 +30,13 @@ def _optimize_model(model: str, config: Optional[str]): required=True, help="Path to the model file (ending in '.pth.tar') you wish to optimize.", ) +@click.option( + "--output-path", + "-o", + type=click.Path(file_okay=False, dir_okay=True), + required=False, + help="Location to save the optimized model, defaults to location of `model`", +) @click.option( "--config", "-c", @@ -35,9 +46,11 @@ def _optimize_model(model: str, config: Optional[str]): type=click.Path(exists=True, file_okay=True, dir_okay=False), ) @append_copyright_to_help -def new_way(model: str, config: Optional[str]): +def new_way( + model: str, config: Optional[str] = None, output_path: Optional[str] = None +): """Generate optimized versions of trained GaNDLF models.""" - _optimize_model(model=model, config=config) + _optimize_model(model=model, config=config, output_path=output_path) # old-fashioned way of running gandlf via `gandlf_optimizeModel`. @@ -62,6 +75,16 @@ def old_way(): help="Path to the model file (ending in '.pth.tar') you wish to optimize.", required=True, ) + parser.add_argument( + "-o", + "--outputdir", + "--output_path", + metavar="", + type=str, + default=None, + help="Location to save the optimized model, defaults to location of `model`", + required=False, + ) parser.add_argument( "-c", "--config", @@ -74,7 +97,7 @@ def old_way(): ) args = parser.parse_args() - _optimize_model(model=args.model, config=args.config) + _optimize_model(model=args.model, config=args.config, output_path=args.outputdir) if __name__ == "__main__":