diff --git a/GANDLF/cli/post_training_model_optimization.py b/GANDLF/cli/post_training_model_optimization.py index 0ca261465..6dc2a4310 100644 --- a/GANDLF/cli/post_training_model_optimization.py +++ b/GANDLF/cli/post_training_model_optimization.py @@ -1,16 +1,21 @@ import os +from pathlib import Path +from typing import Optional from GANDLF.compute import create_pytorch_objects from GANDLF.config_manager import ConfigManager from GANDLF.utils import version_check, load_model, optimize_and_save_model -def post_training_model_optimization(model_path: str, config_path: str) -> bool: +def post_training_model_optimization( + model_path: str, config_path: Optional[str] = None, output_dir: Optional[str] = None +) -> bool: """ CLI function to optimize a model for deployment. Args: model_path (str): Path to the model file. - config_path (str): Path to the config file. + config_path (str, optional): Path to the configuration file. + output_dir (str, optional): Output directory to save the optimized model. Returns: bool: True if successful, False otherwise. 
@@ -26,6 +31,12 @@ def post_training_model_optimization(model_path: str, config_path: str) -> bool: else parameters ) + output_dir = os.path.dirname(model_path) if output_dir is None else output_dir + Path(output_dir).mkdir(parents=True, exist_ok=True) + optimized_model_path = os.path.join( + output_dir, os.path.basename(model_path).replace("pth.tar", "onnx") + ) + # Create PyTorch objects and set onnx_export to True for optimization model, _, _, _, _, parameters = create_pytorch_objects(parameters, device="cpu") parameters["model"]["onnx_export"] = True @@ -35,10 +46,9 @@ def post_training_model_optimization(model_path: str, config_path: str) -> bool: model.load_state_dict(main_dict["model_state_dict"]) # Optimize the model and save it to an ONNX file - optimize_and_save_model(model, parameters, model_path, onnx_export=True) + optimize_and_save_model(model, parameters, optimized_model_path, onnx_export=True) # Check if the optimized model file exists - optimized_model_path = model_path.replace("pth.tar", "onnx") if not os.path.exists(optimized_model_path): print("Error while optimizing the model.") return False diff --git a/GANDLF/entrypoints/optimize_model.py b/GANDLF/entrypoints/optimize_model.py index 7cf5e9a7f..aa1a9a1e0 100644 --- a/GANDLF/entrypoints/optimize_model.py +++ b/GANDLF/entrypoints/optimize_model.py @@ -11,8 +11,12 @@ from GANDLF.entrypoints import append_copyright_to_help -def _optimize_model(model: str, config: Optional[str]): - if post_training_model_optimization(model_path=model, config_path=config): +def _optimize_model( + model: str, config: Optional[str], output_path: Optional[str] = None +): + if post_training_model_optimization( + model_path=model, config_path=config, output_dir=output_path + ): print("Post-training model optimization successful.") else: print("Post-training model optimization failed.") @@ -26,6 +30,13 @@ def _optimize_model(model: str, config: Optional[str]): required=True, help="Path to the model file (ending in
'.pth.tar') you wish to optimize.", ) +@click.option( + "--output-path", + "-o", + type=click.Path(file_okay=False, dir_okay=True), + required=False, + help="Location to save the optimized model, defaults to location of `model`", +) @click.option( "--config", "-c", @@ -35,9 +46,11 @@ def _optimize_model(model: str, config: Optional[str]): type=click.Path(exists=True, file_okay=True, dir_okay=False), ) @append_copyright_to_help -def new_way(model: str, config: Optional[str]): +def new_way( + model: str, config: Optional[str] = None, output_path: Optional[str] = None +): """Generate optimized versions of trained GaNDLF models.""" - _optimize_model(model=model, config=config) + _optimize_model(model=model, config=config, output_path=output_path) # old-fashioned way of running gandlf via `gandlf_optimizeModel`. @@ -62,6 +75,16 @@ def old_way(): help="Path to the model file (ending in '.pth.tar') you wish to optimize.", required=True, ) + parser.add_argument( + "-o", + "--outputdir", + "--output_path", + metavar="", + type=str, + default=None, + help="Location to save the optimized model, defaults to location of `model`", + required=False, + ) parser.add_argument( "-c", "--config", @@ -74,7 +97,7 @@ def old_way(): ) args = parser.parse_args() - _optimize_model(model=args.model, config=args.config) + _optimize_model(model=args.model, config=args.config, output_path=args.outputdir) if __name__ == "__main__": diff --git a/GANDLF/utils/modelio.py b/GANDLF/utils/modelio.py index a83a6d952..d9c069804 100644 --- a/GANDLF/utils/modelio.py +++ b/GANDLF/utils/modelio.py @@ -26,7 +26,10 @@ def optimize_and_save_model( - model: torch.nn.Module, params: dict, path: str, onnx_export: Optional[bool] = True + model: torch.nn.Module, + params: dict, + output_path: str, + onnx_export: Optional[bool] = True, ) -> None: """ Perform post-training optimization and save it to a file. @@ -34,7 +37,7 @@ def optimize_and_save_model( Args: model (torch.nn.Module): Trained torch model. 
params (dict): The parameter dictionary. - path (str): The path to save the model dictionary to. + output_path (str): The path to save the optimized model to. onnx_export (Optional[bool]): Whether to export to ONNX and OpenVINO. Defaults to True. """ # Check if ONNX export is enabled in the parameter dictionary @@ -59,9 +62,7 @@ def optimize_and_save_model( num_channel = params["model"]["num_channels"] model_dimension = params["model"]["dimension"] input_shape = params["patch_size"] - onnx_path = path - if not onnx_path.endswith(".onnx"): - onnx_path = onnx_path.replace("pth.tar", "onnx") + onnx_path = output_path.replace(".pth.tar", ".onnx") if model_dimension == 2: dummy_input = torch.randn( diff --git a/testing/entrypoints/test_optimize_model.py b/testing/entrypoints/test_optimize_model.py index 7da6fca0e..f2002c72e 100644 --- a/testing/entrypoints/test_optimize_model.py +++ b/testing/entrypoints/test_optimize_model.py @@ -17,8 +17,28 @@ TmpFile("model.pth.tar", content="123321"), TmpFile("config.yaml", content="foo: bar"), TmpNoEx("path_na"), + TmpDire("output/"), ] test_cases = [ + CliCase( + should_succeed=True, + new_way_lines=[ + # full command with output + "--model model.pth.tar --config config.yaml --output-path output/", + # tests short arg aliases + "-m model.pth.tar -c config.yaml -o output/", + ], + old_way_lines=[ + "--model model.pth.tar --config config.yaml --output_path output/", + "-m model.pth.tar -c config.yaml -o output/", + ], + expected_args={ + "model_path": "model.pth.tar", + "config_path": "config.yaml", + "output_path": "output/", + "output_dir": None, + }, + ), CliCase( should_succeed=True, new_way_lines=[ @@ -31,7 +51,12 @@ "--model model.pth.tar --config config.yaml", "-m model.pth.tar -c config.yaml", ], - expected_args={"model_path": "model.pth.tar", "config_path": "config.yaml"}, + expected_args={ + "model_path": "model.pth.tar", + "config_path": "config.yaml", + "output_dir": None, + "output_path": None, + }, ), CliCase( 
should_succeed=True, @@ -40,7 +65,12 @@ "-m model.pth.tar" ], old_way_lines=["-m model.pth.tar"], - expected_args={"model_path": "model.pth.tar", "config_path": None}, + expected_args={ + "model_path": "model.pth.tar", + "config_path": None, + "output_path": None, + "output_dir": None, + }, ), CliCase( should_succeed=False, diff --git a/testing/test_full.py b/testing/test_full.py index b9d845916..29cbba099 100644 --- a/testing/test_full.py +++ b/testing/test_full.py @@ -782,7 +782,9 @@ def test_train_inference_optimize_classification_rad_3d(device): # file_config_temp = write_temp_config_path(parameters_temp) model_path = os.path.join(outputDir, all_models_regression[0] + "_best.pth.tar") config_path = os.path.join(outputDir, "parameters.pkl") - optimization_result = post_training_model_optimization(model_path, config_path) + optimization_result = post_training_model_optimization( + model_path, config_path, outputDir + ) assert optimization_result == True, "Optimization should pass" ## testing inference