Skip to content

Commit

Permalink
anomaly save_model bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
eunwoosh committed Oct 27, 2022
1 parent 89754c1 commit fd16958
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 8 deletions.
4 changes: 2 additions & 2 deletions external/anomaly/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def export(self, export_type: ExportType, output_model: ModelEntity) -> None:
output_model.set_data("label_schema.json", label_schema_to_bytes(self.task_environment.label_schema))
self._set_metadata(output_model)

def _model_info(self) -> Dict:
def model_info(self) -> Dict:
"""Return model info to save the model weights.
Returns:
Expand All @@ -282,7 +282,7 @@ def save_model(self, output_model: ModelEntity) -> None:
output_model (ModelEntity): Output model onto which the weights are saved.
"""
logger.info("Saving the model weights.")
model_info = self._model_info()
model_info = self.model_info()
buffer = io.BytesIO()
torch.save(model_info, buffer)
output_model.set_data("weights.pth", buffer.getvalue())
Expand Down
7 changes: 1 addition & 6 deletions ote_cli/ote_cli/utils/hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,12 +517,7 @@ def __init__(
if _is_anomaly_framework_task(task_type):
impl_class = get_impl_class(environment.model_template.entrypoints.base)
task = impl_class(task_environment=environment)
model = ModelEntity(
dataset,
environment.get_model_configuration(),
)
task.save_model(model)
save_model_data(model, self.work_dir)
torch.save(task.model_info(), osp.join(self.work_dir, "weights.pth"))
else:
save_model_data(environment.model, self.work_dir)

Expand Down

0 comments on commit fd16958

Please sign in to comment.