Skip to content

Commit

Permalink
Merge pull request dreoporto#6 from dreoporto/caching-save-history
Browse files Browse the repository at this point in the history
Caching save history
  • Loading branch information
dreoporto authored Sep 12, 2023
2 parents dfc1072 + 7c3af84 commit 7e15522
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 68 deletions.
9 changes: 8 additions & 1 deletion ptmlib/examples/computer_vision_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from tensorflow.keras import layers

import ptmlib.model_tools as modt

import ptmlib.charts as pch

class MyCallback(keras.callbacks.Callback):

Expand Down Expand Up @@ -103,6 +103,13 @@ def main():
print(test_labels[0])
print(max(classifications[0]))

# ensure history data is still available, even if cached
pch.show_history_chart(history, "accuracy") # render again to be sure we have proper history data
print('type(model):', type(model))
print('type(history):', type(history))
print('history.history:', history.history)
print('history.params:', history.params)


if __name__ == '__main__':
main()
29 changes: 27 additions & 2 deletions ptmlib/model_tools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pickle
from typing import Any, List

import matplotlib.image as mpimg
Expand All @@ -8,6 +9,8 @@
import ptmlib.charts as pch
from ptmlib.time import Stopwatch

HISTORY_FILE_SUFFIX_EXTENSION = '_history.pkl'


def get_file_path(model_file_name: str, model_file_format: str = ""):
extension = _get_model_file_extension(model_file_format)
Expand Down Expand Up @@ -36,13 +39,12 @@ def load_or_fit_model(model: Any, model_file_name: str, x: Any, y: Any = None, v
model_file_format: str = "",
load_model_function=_default_load_model_function,
fit_model_function=_default_fit_model_function):
history = None

file_extension = _get_model_file_extension(model_file_format)

if os.path.exists(f'{model_file_name}{file_extension}'):
print(f'Loading existing model file: {model_file_name}{file_extension}')
model = load_model_function(model_file_name, model_file_format)
history = load_history_data(model_file_name)
if images_enabled:
_show_saved_images(metrics, model_file_name, fig_size)
else:
Expand All @@ -52,12 +54,35 @@ def load_or_fit_model(model: Any, model_file_name: str, x: Any, y: Any = None, v
stopwatch.stop()
print(f'Saving new model file: {model_file_name}{file_extension}')
model.save(f'{model_file_name}{file_extension}')
save_history_data(history, model_file_name)
if images_enabled:
_show_new_images(history, model_file_name, metrics)

return model, history


def save_history_data(history: Any, model_file_name: str):
with open(f'{model_file_name}{HISTORY_FILE_SUFFIX_EXTENSION}', 'wb') as history_file:
history_params_tuple = (history.history, history.params)
pickle.dump(history_params_tuple, history_file)


def load_history_data(model_file_name: str):
if not os.path.exists(f'{model_file_name}{HISTORY_FILE_SUFFIX_EXTENSION}'):
return None

# create new history object for return value
history = keras.callbacks.History()

with open(f'{model_file_name}{HISTORY_FILE_SUFFIX_EXTENSION}', 'rb') as history_file:
# load from previously saved history_params_tuple
h, p = pickle.load(history_file)
history.history = h
history.params = p

return history


def _show_new_images(history: Any, model_file_name: str, metrics: List[str]):
if metrics is not None:
for metric in metrics:
Expand Down
Loading

0 comments on commit 7e15522

Please sign in to comment.