
Added output path for optimize model #896

Merged
18 changes: 14 additions & 4 deletions GANDLF/cli/post_training_model_optimization.py
@@ -1,16 +1,21 @@
import os
+from pathlib import Path
+from typing import Optional
from GANDLF.compute import create_pytorch_objects
from GANDLF.config_manager import ConfigManager
from GANDLF.utils import version_check, load_model, optimize_and_save_model


-def post_training_model_optimization(model_path: str, config_path: str) -> bool:
+def post_training_model_optimization(
+    model_path: str, config_path: Optional[str] = None, output_path: Optional[str] = None
+) -> bool:
"""
CLI function to optimize a model for deployment.

Args:
model_path (str): Path to the model file.
-        config_path (str): Path to the config file.
+        config_path (str, optional): Path to the configuration file.
+        output_path (str, optional): Output directory to save the optimized model.

Returns:
bool: True if successful, False otherwise.
@@ -26,6 +31,12 @@ def post_training_model_optimization(model_path: str, config_path: str) -> bool:
else parameters
)

+    output_path = os.path.dirname(model_path) if output_path is None else output_path
+    Path(output_path).mkdir(parents=True, exist_ok=True)
+    optimized_model_path = os.path.join(
+        output_path, os.path.basename(model_path).replace("pth.tar", "onnx")
+    )

# Create PyTorch objects and set onnx_export to True for optimization
model, _, _, _, _, parameters = create_pytorch_objects(parameters, device="cpu")
parameters["model"]["onnx_export"] = True
@@ -35,10 +46,9 @@ def post_training_model_optimization(model_path: str, config_path: str) -> bool:
model.load_state_dict(main_dict["model_state_dict"])

# Optimize the model and save it to an ONNX file
-    optimize_and_save_model(model, parameters, model_path, onnx_export=True)
+    optimize_and_save_model(model, parameters, optimized_model_path, onnx_export=True)

# Check if the optimized model file exists
-    optimized_model_path = model_path.replace("pth.tar", "onnx")
if not os.path.exists(optimized_model_path):
print("Error while optimizing the model.")
return False
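
Taken together, these changes let callers choose where the optimized model lands instead of it always being written next to the input checkpoint. A minimal usage sketch of the updated function (the checkpoint and config paths are hypothetical placeholders):

```python
from GANDLF.cli.post_training_model_optimization import post_training_model_optimization

# With output_path=None the ONNX file is written next to the checkpoint,
# exactly as before this PR; passing a directory redirects it there.
ok = post_training_model_optimization(
    model_path="model_final.pth.tar",  # hypothetical trained checkpoint
    config_path="config.yaml",  # optional; when None, previously loaded parameters are reused
    output_path="optimized/",  # created on demand via Path(...).mkdir(parents=True, exist_ok=True)
)
print("optimization succeeded" if ok else "optimization failed")
```
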
33 changes: 28 additions & 5 deletions GANDLF/entrypoints/optimize_model.py
@@ -11,8 +11,12 @@
from GANDLF.entrypoints import append_copyright_to_help


-def _optimize_model(model: str, config: Optional[str]):
-    if post_training_model_optimization(model_path=model, config_path=config):
+def _optimize_model(
+    model: str, config: Optional[str], output_path: Optional[str] = None
+):
+    if post_training_model_optimization(
+        model_path=model, config_path=config, output_path=output_path
+    ):
print("Post-training model optimization successful.")
else:
print("Post-training model optimization failed.")
@@ -26,6 +30,13 @@ def _optimize_model(model: str, config: Optional[str]):
required=True,
help="Path to the model file (ending in '.pth.tar') you wish to optimize.",
)
+@click.option(
+    "--output-path",
+    "-o",
+    type=click.Path(file_okay=False, dir_okay=True),
+    required=False,
+    help="Location to save the optimized model, defaults to location of `model`",
+)
@click.option(
"--config",
"-c",
@@ -35,9 +46,11 @@ def _optimize_model(model: str, config: Optional[str]):
type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
@append_copyright_to_help
-def new_way(model: str, config: Optional[str]):
+def new_way(
+    model: str, config: Optional[str] = None, output_path: Optional[str] = None
+):
"""Generate optimized versions of trained GaNDLF models."""
-    _optimize_model(model=model, config=config)
+    _optimize_model(model=model, config=config, output_path=output_path)


# old-fashioned way of running gandlf via `gandlf_optimizeModel`.
@@ -62,6 +75,16 @@ def old_way():
help="Path to the model file (ending in '.pth.tar') you wish to optimize.",
required=True,
)
+    parser.add_argument(
+        "-o",
+        "--outputdir",
+        "--output_path",
+        metavar="",
+        type=str,
+        default=None,
+        help="Location to save the optimized model, defaults to location of `model`",
+        required=False,
+    )
parser.add_argument(
"-c",
"--config",
@@ -74,7 +97,7 @@ def old_way():
)

args = parser.parse_args()
-    _optimize_model(model=args.model, config=args.config)
+    _optimize_model(model=args.model, config=args.config, output_path=args.outputdir)


if __name__ == "__main__":
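
The new flag is exposed as `--output-path/-o` in the click-based CLI and as `--outputdir/--output_path/-o` in the legacy `gandlf_optimizeModel` parser. A quick way to exercise it programmatically is click's built-in test runner; a sketch, assuming `model.pth.tar` and `config.yaml` exist in the working directory (the options validate paths):

```python
from click.testing import CliRunner
from GANDLF.entrypoints.optimize_model import new_way

runner = CliRunner()
# Exercises the same code path as the command-line flags shown above.
result = runner.invoke(
    new_way,
    ["--model", "model.pth.tar", "--config", "config.yaml", "--output-path", "output/"],
)
print(result.exit_code, result.output)
```
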
11 changes: 6 additions & 5 deletions GANDLF/utils/modelio.py
@@ -26,15 +26,18 @@


def optimize_and_save_model(
-    model: torch.nn.Module, params: dict, path: str, onnx_export: Optional[bool] = True
+    model: torch.nn.Module,
+    params: dict,
+    output_path: str,
+    onnx_export: Optional[bool] = True,
) -> None:
"""
Perform post-training optimization and save it to a file.

Args:
model (torch.nn.Module): Trained torch model.
params (dict): The parameter dictionary.
-        path (str): The path to save the model dictionary to.
+        output_path (str): The path to save the optimized model to.
onnx_export (Optional[bool]): Whether to export to ONNX and OpenVINO. Defaults to True.
"""
# Check if ONNX export is enabled in the parameter dictionary
@@ -59,9 +62,7 @@ def optimize_and_save_model(
num_channel = params["model"]["num_channels"]
model_dimension = params["model"]["dimension"]
input_shape = params["patch_size"]
-    onnx_path = path
-    if not onnx_path.endswith(".onnx"):
-        onnx_path = onnx_path.replace("pth.tar", "onnx")
+    onnx_path = output_path.replace(".pth.tar", ".onnx")

if model_dimension == 2:
dummy_input = torch.randn(
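
Note the simplification here: the old code only rewrote the suffix when the path did not already end in `.onnx`, while the new line relies on `str.replace` being a no-op when `.pth.tar` is absent. A small illustration of the derivation (paths are hypothetical):

```python
# ".pth.tar" is swapped for ".onnx"; any other path passes through unchanged.
assert "/results/model_best.pth.tar".replace(".pth.tar", ".onnx") == "/results/model_best.onnx"
assert "/results/model_best.onnx".replace(".pth.tar", ".onnx") == "/results/model_best.onnx"
```
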
34 changes: 32 additions & 2 deletions testing/entrypoints/test_optimize_model.py
@@ -17,8 +17,28 @@
TmpFile("model.pth.tar", content="123321"),
TmpFile("config.yaml", content="foo: bar"),
TmpNoEx("path_na"),
TmpDire("output/"),
]
test_cases = [
+    CliCase(
+        should_succeed=True,
+        new_way_lines=[
+            # full command with output
+            "--model model.pth.tar --config config.yaml --output-path output/",
+            # tests short arg aliases
+            "-m model.pth.tar -c config.yaml -o output/",
+        ],
+        old_way_lines=[
+            "--model model.pth.tar --config config.yaml --output_path output/",
+            "-m model.pth.tar -c config.yaml -o output/",
+        ],
+        expected_args={
+            "model_path": "model.pth.tar",
+            "config_path": "config.yaml",
+            "output_path": "output/",
+        },
+    ),
CliCase(
should_succeed=True,
new_way_lines=[
@@ -31,7 +51,12 @@
"--model model.pth.tar --config config.yaml",
"-m model.pth.tar -c config.yaml",
],
expected_args={"model_path": "model.pth.tar", "config_path": "config.yaml"},
expected_args={
"model_path": "model.pth.tar",
"config_path": "config.yaml",
"output_dir": None,
"output_path": None,
},
),
CliCase(
should_succeed=True,
@@ -40,7 +65,12 @@
"-m model.pth.tar"
],
old_way_lines=["-m model.pth.tar"],
expected_args={"model_path": "model.pth.tar", "config_path": None},
expected_args={
"model_path": "model.pth.tar",
"config_path": None,
"output_path": None,
"output_dir": None,
},
),
CliCase(
should_succeed=False,
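
Each `CliCase` drives both entrypoints and asserts that the patched `post_training_model_optimization` receives the expected keyword arguments. A rough standalone equivalent of the first case; the patch target and the pre-created temp files mirror what the harness arranges, and those details are assumptions:

```python
from unittest import mock

from click.testing import CliRunner
from GANDLF.entrypoints.optimize_model import new_way

with mock.patch(
    "GANDLF.entrypoints.optimize_model.post_training_model_optimization",
    return_value=True,
) as patched:
    # Assumes model.pth.tar, config.yaml, and output/ already exist in the CWD,
    # as the TmpFile/TmpDire fixtures above provide.
    CliRunner().invoke(
        new_way, ["-m", "model.pth.tar", "-c", "config.yaml", "-o", "output/"]
    )
    patched.assert_called_once_with(
        model_path="model.pth.tar", config_path="config.yaml", output_path="output/"
    )
```
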
4 changes: 3 additions & 1 deletion testing/test_full.py
@@ -782,7 +782,9 @@ def test_train_inference_optimize_classification_rad_3d(device):
# file_config_temp = write_temp_config_path(parameters_temp)
model_path = os.path.join(outputDir, all_models_regression[0] + "_best.pth.tar")
config_path = os.path.join(outputDir, "parameters.pkl")
-    optimization_result = post_training_model_optimization(model_path, config_path)
+    optimization_result = post_training_model_optimization(
+        model_path, config_path, outputDir
+    )
assert optimization_result == True, "Optimization should pass"

## testing inference