Skip to content

Commit

Permalink
Merge pull request #896 from mlcommons/new-apis_v0.1.0-dev_optmize-mo…
Browse files Browse the repository at this point in the history
…del_with_output

Added output path for optimize model
  • Loading branch information
sarthakpati authored Jul 13, 2024
2 parents 2c61453 + 79be7b6 commit f125e2d
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 17 deletions.
18 changes: 14 additions & 4 deletions GANDLF/cli/post_training_model_optimization.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import os
from pathlib import Path
from typing import Optional
from GANDLF.compute import create_pytorch_objects
from GANDLF.config_manager import ConfigManager
from GANDLF.utils import version_check, load_model, optimize_and_save_model


def post_training_model_optimization(model_path: str, config_path: str) -> bool:
def post_training_model_optimization(
model_path: str, config_path: Optional[str] = None, output_dir: Optional[str] = None
) -> bool:
"""
CLI function to optimize a model for deployment.
Args:
model_path (str): Path to the model file.
config_path (str): Path to the config file.
config_path (str, optional): Path to the configuration file.
output_dir (str, optional): Output directory to save the optimized model.
Returns:
bool: True if successful, False otherwise.
Expand All @@ -26,6 +31,12 @@ def post_training_model_optimization(model_path: str, config_path: str) -> bool:
else parameters
)

output_dir = os.path.dirname(model_path) if output_dir is None else output_dir
Path(output_dir).mkdir(parents=True, exist_ok=True)
optimized_model_path = os.path.join(
output_dir, os.path.basename(model_path).replace("pth.tar", "onnx")
)

# Create PyTorch objects and set onnx_export to True for optimization
model, _, _, _, _, parameters = create_pytorch_objects(parameters, device="cpu")
parameters["model"]["onnx_export"] = True
Expand All @@ -35,10 +46,9 @@ def post_training_model_optimization(model_path: str, config_path: str) -> bool:
model.load_state_dict(main_dict["model_state_dict"])

# Optimize the model and save it to an ONNX file
optimize_and_save_model(model, parameters, model_path, onnx_export=True)
optimize_and_save_model(model, parameters, optimized_model_path, onnx_export=True)

# Check if the optimized model file exists
optimized_model_path = model_path.replace("pth.tar", "onnx")
if not os.path.exists(optimized_model_path):
print("Error while optimizing the model.")
return False
Expand Down
33 changes: 28 additions & 5 deletions GANDLF/entrypoints/optimize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@
from GANDLF.entrypoints import append_copyright_to_help


def _optimize_model(model: str, config: Optional[str]):
if post_training_model_optimization(model_path=model, config_path=config):
def _optimize_model(
model: str, config: Optional[str], output_path: Optional[str] = None
):
if post_training_model_optimization(
model_path=model, config_path=config, output_path=output_path
):
print("Post-training model optimization successful.")

Check warning on line 20 in GANDLF/entrypoints/optimize_model.py

View check run for this annotation

Codecov / codecov/patch

GANDLF/entrypoints/optimize_model.py#L20

Added line #L20 was not covered by tests
else:
print("Post-training model optimization failed.")
Expand All @@ -26,6 +30,13 @@ def _optimize_model(model: str, config: Optional[str]):
required=True,
help="Path to the model file (ending in '.pth.tar') you wish to optimize.",
)
@click.option(
"--output-path",
"-o",
type=click.Path(file_okay=False, dir_okay=True),
required=False,
help="Location to save the optimized model, defaults to location of `model`",
)
@click.option(
"--config",
"-c",
Expand All @@ -35,9 +46,11 @@ def _optimize_model(model: str, config: Optional[str]):
type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
@append_copyright_to_help
def new_way(model: str, config: Optional[str]):
def new_way(
model: str, config: Optional[str] = None, output_path: Optional[str] = None
):
"""Generate optimized versions of trained GaNDLF models."""
_optimize_model(model=model, config=config)
_optimize_model(model=model, config=config, output_path=output_path)


# old-fashioned way of running gandlf via `gandlf_optimizeModel`.
Expand All @@ -62,6 +75,16 @@ def old_way():
help="Path to the model file (ending in '.pth.tar') you wish to optimize.",
required=True,
)
parser.add_argument(
"-o",
"--outputdir",
"--output_path",
metavar="",
type=str,
default=None,
help="Location to save the optimized model, defaults to location of `model`",
required=False,
)
parser.add_argument(
"-c",
"--config",
Expand All @@ -74,7 +97,7 @@ def old_way():
)

args = parser.parse_args()
_optimize_model(model=args.model, config=args.config)
_optimize_model(model=args.model, config=args.config, output_path=args.outputdir)


if __name__ == "__main__":
Expand Down
11 changes: 6 additions & 5 deletions GANDLF/utils/modelio.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,18 @@


def optimize_and_save_model(
model: torch.nn.Module, params: dict, path: str, onnx_export: Optional[bool] = True
model: torch.nn.Module,
params: dict,
output_path: str,
onnx_export: Optional[bool] = True,
) -> None:
"""
Perform post-training optimization and save it to a file.
Args:
model (torch.nn.Module): Trained torch model.
params (dict): The parameter dictionary.
path (str): The path to save the model dictionary to.
output_path (str): The path to save the optimized model to.
onnx_export (Optional[bool]): Whether to export to ONNX and OpenVINO. Defaults to True.
"""
# Check if ONNX export is enabled in the parameter dictionary
Expand All @@ -59,9 +62,7 @@ def optimize_and_save_model(
num_channel = params["model"]["num_channels"]
model_dimension = params["model"]["dimension"]
input_shape = params["patch_size"]
onnx_path = path
if not onnx_path.endswith(".onnx"):
onnx_path = onnx_path.replace("pth.tar", "onnx")
onnx_path = output_path.replace(".pth.tar", ".onnx")

if model_dimension == 2:
dummy_input = torch.randn(
Expand Down
34 changes: 32 additions & 2 deletions testing/entrypoints/test_optimize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,28 @@
TmpFile("model.pth.tar", content="123321"),
TmpFile("config.yaml", content="foo: bar"),
TmpNoEx("path_na"),
TmpDire("output/"),
]
test_cases = [
CliCase(
should_succeed=True,
new_way_lines=[
# full command with output
"--model model.pth.tar --config config.yaml --output-path output/",
# tests short arg aliases
"-m model.pth.tar -c config.yaml -o output/",
],
old_way_lines=[
"--model model.pth.tar --config config.yaml --output_path output/",
"-m model.pth.tar -c config.yaml -o output/",
],
expected_args={
"model_path": "model.pth.tar",
"config_path": "config.yaml",
"output_path": "output/",
"output_dir": None,
},
),
CliCase(
should_succeed=True,
new_way_lines=[
Expand All @@ -31,7 +51,12 @@
"--model model.pth.tar --config config.yaml",
"-m model.pth.tar -c config.yaml",
],
expected_args={"model_path": "model.pth.tar", "config_path": "config.yaml"},
expected_args={
"model_path": "model.pth.tar",
"config_path": "config.yaml",
"output_dir": None,
"output_path": None,
},
),
CliCase(
should_succeed=True,
Expand All @@ -40,7 +65,12 @@
"-m model.pth.tar"
],
old_way_lines=["-m model.pth.tar"],
expected_args={"model_path": "model.pth.tar", "config_path": None},
expected_args={
"model_path": "model.pth.tar",
"config_path": None,
"output_path": None,
"output_dir": None,
},
),
CliCase(
should_succeed=False,
Expand Down
4 changes: 3 additions & 1 deletion testing/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,9 @@ def test_train_inference_optimize_classification_rad_3d(device):
# file_config_temp = write_temp_config_path(parameters_temp)
model_path = os.path.join(outputDir, all_models_regression[0] + "_best.pth.tar")
config_path = os.path.join(outputDir, "parameters.pkl")
optimization_result = post_training_model_optimization(model_path, config_path)
optimization_result = post_training_model_optimization(
model_path, config_path, outputDir
)
assert optimization_result == True, "Optimization should pass"

## testing inference
Expand Down

0 comments on commit f125e2d

Please sign in to comment.