From 8efe7f19c71c070589451ea8fc2e423624afaf52 Mon Sep 17 00:00:00 2001 From: scap3yvt <149599669+scap3yvt@users.noreply.github.com> Date: Tue, 9 Jul 2024 13:40:46 -0400 Subject: [PATCH 01/11] added output_dir for model optimization --- GANDLF/cli/post_training_model_optimization.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/GANDLF/cli/post_training_model_optimization.py b/GANDLF/cli/post_training_model_optimization.py index 0ca261465..b82d20739 100644 --- a/GANDLF/cli/post_training_model_optimization.py +++ b/GANDLF/cli/post_training_model_optimization.py @@ -1,16 +1,21 @@ import os +from pathlib import Path +from typing import Optional from GANDLF.compute import create_pytorch_objects from GANDLF.config_manager import ConfigManager from GANDLF.utils import version_check, load_model, optimize_and_save_model -def post_training_model_optimization(model_path: str, config_path: str) -> bool: +def post_training_model_optimization( + model_path: str, config_path: Optional[str] = None, output_dir: Optional[str] = None +) -> bool: """ CLI function to optimize a model for deployment. Args: model_path (str): Path to the model file. - config_path (str): Path to the config file. + config_path (str, optional): Path to the configuration file. + output_dir (str, optional): Output directory to save the optimized model. Returns: bool: True if successful, False otherwise. @@ -26,6 +31,12 @@ def post_training_model_optimization(model_path: str, config_path: str) -> bool: else parameters ) + output_dir = os.path.dirname(model_path) if output_dir is None else output_dir + Path.mkdir(output_dir, parents=True, exist_ok=True) + optimized_model_path = os.path.join( + output_dir, os.path.basename(model_path).replace("pth.tar", "onnx") + ) + # Create PyTorch objects and set onnx_export to True for optimization model, _, _, _, _, parameters = create_pytorch_objects(parameters, device="cpu") parameters["model"]["onnx_export"] = True @@ -35,10 +46,9 @@ def post_training_model_optimization(model_path: str, config_path: str) -> bool: model.load_state_dict(main_dict["model_state_dict"]) # Optimize the model and save it to an ONNX file - optimize_and_save_model(model, parameters, model_path, onnx_export=True) + optimize_and_save_model(model, parameters, optimized_model_path, onnx_export=True) # Check if the optimized model file exists - optimized_model_path = model_path.replace("pth.tar", "onnx") if not os.path.exists(optimized_model_path): print("Error while optimizing the model.") return False From 3ef425ce73d00f5f3a8bbd0ca1ecff7ed98ef389 Mon Sep 17 00:00:00 2001 From: scap3yvt <149599669+scap3yvt@users.noreply.github.com> Date: Tue, 9 Jul 2024 13:41:04 -0400 Subject: [PATCH 02/11] added logic to take the output_path directly --- GANDLF/utils/modelio.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/GANDLF/utils/modelio.py b/GANDLF/utils/modelio.py index a83a6d952..d9c069804 100644 --- a/GANDLF/utils/modelio.py +++ b/GANDLF/utils/modelio.py @@ -26,7 +26,10 @@ def optimize_and_save_model( - model: torch.nn.Module, params: dict, path: str, onnx_export: Optional[bool] = True + model: torch.nn.Module, + params: dict, + output_path: str, + onnx_export: Optional[bool] = True, ) -> None: """ Perform post-training optimization and save it to a file. @@ -34,7 +37,7 @@ def optimize_and_save_model( Args: model (torch.nn.Module): Trained torch model. params (dict): The parameter dictionary. - path (str): The path to save the model dictionary to. + output_path (str): The path to save the optimized model to. onnx_export (Optional[bool]): Whether to export to ONNX and OpenVINO. Defaults to True. """ # Check if ONNX export is enabled in the parameter dictionary @@ -59,9 +62,7 @@ def optimize_and_save_model( num_channel = params["model"]["num_channels"] model_dimension = params["model"]["dimension"] input_shape = params["patch_size"] - onnx_path = path - if not onnx_path.endswith(".onnx"): - onnx_path = onnx_path.replace("pth.tar", "onnx") + onnx_path = output_path.replace(".pth.tar", ".onnx") if model_dimension == 2: dummy_input = torch.randn( From 40b0cdbfbdeadcbeef3dedf97ba103074f8135c1 Mon Sep 17 00:00:00 2001 From: scap3yvt <149599669+scap3yvt@users.noreply.github.com> Date: Tue, 9 Jul 2024 13:43:16 -0400 Subject: [PATCH 03/11] updated cli --- GANDLF/entrypoints/optimize_model.py | 33 +++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/GANDLF/entrypoints/optimize_model.py b/GANDLF/entrypoints/optimize_model.py index 7cf5e9a7f..aa1a9a1e0 100644 --- a/GANDLF/entrypoints/optimize_model.py +++ b/GANDLF/entrypoints/optimize_model.py @@ -11,8 +11,12 @@ from GANDLF.entrypoints import append_copyright_to_help -def _optimize_model(model: str, config: Optional[str]): - if post_training_model_optimization(model_path=model, config_path=config): +def _optimize_model( + model: str, config: Optional[str], output_path: Optional[str] = None +): + if post_training_model_optimization( + model_path=model, config_path=config, output_path=output_path + ): print("Post-training model optimization successful.") else: print("Post-training model optimization failed.") @@ -26,6 +30,13 @@ def _optimize_model(model: str, config: Optional[str]): required=True, help="Path to the model file (ending in '.pth.tar') you wish to optimize.", ) +@click.option( + "--output-path", + "-o", + type=click.Path(file_okay=False, dir_okay=True), + required=False, + help="Location to save the optimized model, defaults to location of `model`", +) @click.option( "--config", "-c", @@ -35,9 +46,11 @@ def _optimize_model(model: str, config: Optional[str]): type=click.Path(exists=True, file_okay=True, dir_okay=False), ) @append_copyright_to_help -def new_way(model: str, config: Optional[str]): +def new_way( + model: str, config: Optional[str] = None, output_path: Optional[str] = None +): """Generate optimized versions of trained GaNDLF models.""" - _optimize_model(model=model, config=config) + _optimize_model(model=model, config=config, output_path=output_path) # old-fashioned way of running gandlf via `gandlf_optimizeModel`. @@ -62,6 +75,16 @@ def old_way(): help="Path to the model file (ending in '.pth.tar') you wish to optimize.", required=True, ) + parser.add_argument( + "-o", + "--outputdir", + "--output_path", + metavar="", + type=str, + default=None, + help="Location to save the optimized model, defaults to location of `model`", + required=False, + ) parser.add_argument( "-c", "--config", @@ -74,7 +97,7 @@ def old_way(): ) args = parser.parse_args() - _optimize_model(model=args.model, config=args.config) + _optimize_model(model=args.model, config=args.config, output_path=args.outputdir) if __name__ == "__main__": From 7deb9b41c7cbb4b2d90b918f8cd1e387cd8ea1cc Mon Sep 17 00:00:00 2001 From: scap3yvt <149599669+scap3yvt@users.noreply.github.com> Date: Wed, 10 Jul 2024 14:51:58 -0400 Subject: [PATCH 04/11] added test for output-path --- testing/entrypoints/test_optimize_model.py | 25 +++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/testing/entrypoints/test_optimize_model.py b/testing/entrypoints/test_optimize_model.py index 7da6fca0e..b53b58ea4 100644 --- a/testing/entrypoints/test_optimize_model.py +++ b/testing/entrypoints/test_optimize_model.py @@ -17,8 +17,27 @@ TmpFile("model.pth.tar", content="123321"), TmpFile("config.yaml", content="foo: bar"), TmpNoEx("path_na"), + TmpDire("output/"), ] test_cases = [ + CliCase( + should_succeed=True, + new_way_lines=[ + # full command with output + "--model model.pth.tar --config config.yaml --output-path output/", + # tests short arg aliases + "-m model.pth.tar -c config.yaml -o output/", + ], + old_way_lines=[ + "--model model.pth.tar --config config.yaml --output_path output/", + "-m model.pth.tar -c config.yaml -o output/", + ], + expected_args={ + "model_path": "model.pth.tar", + "config_path": "config.yaml", + "output_path": "output/", + }, + ), CliCase( should_succeed=True, new_way_lines=[ @@ -31,7 +50,11 @@ "--model model.pth.tar --config config.yaml", "-m model.pth.tar -c config.yaml", ], - expected_args={"model_path": "model.pth.tar", "config_path": "config.yaml"}, + expected_args={ + "model_path": "model.pth.tar", + "config_path": "config.yaml", + "output_path": None, + }, ), CliCase( should_succeed=True, From 0d887cde3e5daefb97c9e2d5bb3cd702d931f36d Mon Sep 17 00:00:00 2001 From: scap3yvt <149599669+scap3yvt@users.noreply.github.com> Date: Wed, 10 Jul 2024 15:36:53 -0400 Subject: [PATCH 05/11] pass `outputDir` separately --- testing/test_full.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testing/test_full.py b/testing/test_full.py index 957ad52f2..624c4f68e 100644 --- a/testing/test_full.py +++ b/testing/test_full.py @@ -782,7 +782,7 @@ def test_train_inference_optimize_classification_rad_3d(device): # file_config_temp = write_temp_config_path(parameters_temp) model_path = os.path.join(outputDir, all_models_regression[0] + "_best.pth.tar") config_path = os.path.join(outputDir, "parameters.pkl") - optimization_result = post_training_model_optimization(model_path, config_path) + optimization_result = post_training_model_optimization(model_path, config_path, outputDir) assert optimization_result == True, "Optimization should pass" ## testing inference From 4d748887259479994d8f68c2801f9d6210644cf9 Mon Sep 17 00:00:00 2001 From: scap3yvt <149599669+scap3yvt@users.noreply.github.com> Date: Wed, 10 Jul 2024 15:52:31 -0400 Subject: [PATCH 06/11] lint --- testing/test_full.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/testing/test_full.py b/testing/test_full.py index 624c4f68e..87af4f72c 100644 --- a/testing/test_full.py +++ b/testing/test_full.py @@ -782,7 +782,9 @@ def test_train_inference_optimize_classification_rad_3d(device): # file_config_temp = write_temp_config_path(parameters_temp) model_path = os.path.join(outputDir, all_models_regression[0] + "_best.pth.tar") config_path = os.path.join(outputDir, "parameters.pkl") - optimization_result = post_training_model_optimization(model_path, config_path, outputDir) + optimization_result = post_training_model_optimization( + model_path, config_path, outputDir + ) assert optimization_result == True, "Optimization should pass" ## testing inference From 4ce39d263a4abe01ed02a413dba890667af391dd Mon Sep 17 00:00:00 2001 From: scap3yvt <149599669+scap3yvt@users.noreply.github.com> Date: Wed, 10 Jul 2024 16:08:27 -0400 Subject: [PATCH 07/11] syntax fixed --- GANDLF/cli/post_training_model_optimization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GANDLF/cli/post_training_model_optimization.py b/GANDLF/cli/post_training_model_optimization.py index b82d20739..6dc2a4310 100644 --- a/GANDLF/cli/post_training_model_optimization.py +++ b/GANDLF/cli/post_training_model_optimization.py @@ -32,7 +32,7 @@ def post_training_model_optimization( ) output_dir = os.path.dirname(model_path) if output_dir is None else output_dir - Path.mkdir(output_dir, parents=True, exist_ok=True) + Path(output_dir).mkdir(parents=True, exist_ok=True) optimized_model_path = os.path.join( output_dir, os.path.basename(model_path).replace("pth.tar", "onnx") ) From 5853763e246e78d5443c7876e84e34f3c16fd6c2 Mon Sep 17 00:00:00 2001 From: scap3yvt <149599669+scap3yvt@users.noreply.github.com> Date: Thu, 11 Jul 2024 09:21:05 -0400 Subject: [PATCH 08/11] fixed check for params --- testing/entrypoints/test_optimize_model.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/testing/entrypoints/test_optimize_model.py b/testing/entrypoints/test_optimize_model.py index b53b58ea4..4a40f80df 100644 --- a/testing/entrypoints/test_optimize_model.py +++ b/testing/entrypoints/test_optimize_model.py @@ -35,7 +35,7 @@ expected_args={ "model_path": "model.pth.tar", "config_path": "config.yaml", - "output_path": "output/", + "output_dir": "output/", }, ), CliCase( @@ -53,7 +53,7 @@ expected_args={ "model_path": "model.pth.tar", "config_path": "config.yaml", - "output_path": None, + "output_dir": None, }, ), CliCase( @@ -63,7 +63,12 @@ "-m model.pth.tar" ], old_way_lines=["-m model.pth.tar"], - expected_args={"model_path": "model.pth.tar", "config_path": None}, + expected_args={ + "model_path": "model.pth.tar", + "config_path": None, + "output_path": None, + "output_dir": None, + }, ), CliCase( should_succeed=False, From d11c7693af48ce2da0e7aff53f7b34d49d0d0b56 Mon Sep 17 00:00:00 2001 From: scap3yvt <149599669+scap3yvt@users.noreply.github.com> Date: Thu, 11 Jul 2024 10:42:21 -0400 Subject: [PATCH 09/11] tests updated --- testing/entrypoints/test_optimize_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/testing/entrypoints/test_optimize_model.py b/testing/entrypoints/test_optimize_model.py index 4a40f80df..54c3974db 100644 --- a/testing/entrypoints/test_optimize_model.py +++ b/testing/entrypoints/test_optimize_model.py @@ -36,6 +36,7 @@ "model_path": "model.pth.tar", "config_path": "config.yaml", "output_dir": "output/", + "output_path": "output/", }, ), CliCase( @@ -54,6 +55,7 @@ "model_path": "model.pth.tar", "config_path": "config.yaml", "output_dir": None, + "output_path": None, }, ), CliCase( From d32e237bd291acd0ddfd8ca8581fe2f23f7d90a0 Mon Sep 17 00:00:00 2001 From: scap3yvt <149599669+scap3yvt@users.noreply.github.com> Date: Thu, 11 Jul 2024 11:26:50 -0400 Subject: [PATCH 10/11] tests should pass --- testing/entrypoints/test_optimize_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testing/entrypoints/test_optimize_model.py b/testing/entrypoints/test_optimize_model.py index 54c3974db..0f5f1e83b 100644 --- a/testing/entrypoints/test_optimize_model.py +++ b/testing/entrypoints/test_optimize_model.py @@ -36,7 +36,7 @@ "model_path": "model.pth.tar", "config_path": "config.yaml", "output_dir": "output/", - "output_path": "output/", + "output_path": None, }, ), CliCase( From 52b276622c2d9dd0f5b3c5684ec58fcb159aafa2 Mon Sep 17 00:00:00 2001 From: scap3yvt <149599669+scap3yvt@users.noreply.github.com> Date: Fri, 12 Jul 2024 21:44:37 -0400 Subject: [PATCH 11/11] fixed tests --- testing/entrypoints/test_optimize_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/testing/entrypoints/test_optimize_model.py b/testing/entrypoints/test_optimize_model.py index 0f5f1e83b..f2002c72e 100644 --- a/testing/entrypoints/test_optimize_model.py +++ b/testing/entrypoints/test_optimize_model.py @@ -35,8 +35,8 @@ expected_args={ "model_path": "model.pth.tar", "config_path": "config.yaml", - "output_dir": "output/", - "output_path": None, + "output_path": "output/", + "output_dir": None, }, ), CliCase(