Save model #124
Save model after training and be able to reload it for inference in the infer function. Enables using the model on a test set.
Adjusted pipeline.py to be able to save and restore trained models.
kglib/kgcn/pipeline/pipeline.py
Outdated
log_dir=output_dir)
# train
else:
train_values, test_values, tr_info = learner(input_graphs,
Now that you've added learner.test(...), it would be ideal to change this method to be learner.train(...). I don't know if this breaks anything though; I can't remember whether it will cause an issue with Sonnet, but it's worth a try. That means renaming KGCNLearner's __call__ magic method to train.
Sounds logical.
I adjusted it. We will see what happens.
kglib/kgcn/learn/learn.py
Outdated
class KGCNLearner:
"""
Responsible for running a KGCN model
"""
- def __init__(self, model, num_processing_steps_tr=10, num_processing_steps_ge=10):
+ def __init__(self, model, num_processing_steps_tr=10, num_processing_steps_ge=10, save_fle="save_model.txt", reload_fle=''):
Looking at this, I think we instead want to have save_file as an argument to the method that does the training, and load_file can be a mandatory argument for the method that does the testing. That way, the user can make a single KGCNLearner and train it multiple times, saving to multiple different files.
Adding load_file to the test method gives us a guarantee for the test method that there is a file provided.
Agree. Adjusted it
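The suggestion above can be sketched roughly as follows. This is only an illustration of the signature shape, not kglib's actual code: ToyLearner, its weights attribute, and the JSON file format are hypothetical stand-ins for KGCNLearner and a TensorFlow checkpoint. The point is that save_file is an optional argument of train, while load_file is a required argument of test.

```python
import json
import os
import tempfile


class ToyLearner:
    """Hypothetical stand-in for KGCNLearner; not the real kglib class."""

    def __init__(self, num_processing_steps=10):
        self.num_processing_steps = num_processing_steps
        self.weights = None

    def train(self, data, save_file=None):
        # Placeholder "training": store the mean of the data as the only weight.
        self.weights = {"mean": sum(data) / len(data)}
        if save_file is not None:
            with open(save_file, "w") as f:
                json.dump(self.weights, f)
        return self.weights

    def test(self, data, load_file):
        # load_file has no default, so a caller cannot forget to supply it.
        with open(load_file) as f:
            self.weights = json.load(f)
        return [x - self.weights["mean"] for x in data]


# One learner, trained twice, saved to two different files.
learner = ToyLearner()
out_dir = tempfile.mkdtemp()
file_a = os.path.join(out_dir, "model_a.json")
file_b = os.path.join(out_dir, "model_b.json")
learner.train([1.0, 2.0, 3.0], save_file=file_a)
learner.train([10.0, 20.0, 30.0], save_file=file_b)

# Each test call names the saved model it should restore.
residuals_a = learner.test([2.0, 4.0], load_file=file_a)
residuals_b = learner.test([20.0, 40.0], load_file=file_b)
```

Because the file names live on the method calls rather than on the constructor, the same learner object can be reused across runs without being rebuilt.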
kglib/kgcn/pipeline/pipeline.py
Outdated
reload_fle=f'{output_dir}/{reload_fle}')

# only test
if not Path(output_dir / reload_fle).is_dir() and do_test == True:
This line currently causes a test to fail with the following error:
======================================================================
ERROR: test_learning_is_done (__main__.TestDiagnosisExample)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/circleci/.cache/bazel/_bazel_circleci/046479231ae58926fd02b6cd39e690ed/execroot/kglib/bazel-out/k8-fastbuild/bin/tests/end_to_end/diagnosis.runfiles/kglib/tests/end_to_end/kgcn/diagnosis.py", line 39, in test_learning_is_done
solveds_tr, solveds_ge = diagnosis_example()
File "/home/circleci/.cache/bazel/_bazel_circleci/046479231ae58926fd02b6cd39e690ed/execroot/kglib/bazel-out/k8-fastbuild/bin/tests/end_to_end/diagnosis.runfiles/kglib/kglib/kgcn/examples/diagnosis/diagnosis.py", line 108, in diagnosis_example
output_dir=f"./events/{time.time()}/")
File "/home/circleci/.cache/bazel/_bazel_circleci/046479231ae58926fd02b6cd39e690ed/execroot/kglib/bazel-out/k8-fastbuild/bin/tests/end_to_end/diagnosis.runfiles/kglib/kglib/kgcn/pipeline/pipeline.py", line 95, in pipeline
if not Path(output_dir / reload_fle).is_dir() and do_test == True:
TypeError: unsupported operand type(s) for /: 'str' and 'str'
----------------------------------------------------------------------
Ran 1 test in 255.269s
FAILED (errors=1)
You can see the test here. You can navigate to the tests via the tick or cross that's visible on the right of each commit on the PR. You can also run all of the tests locally on your own machine if you're up and running with bazel.
This line looks like it should be something like Path(f'{output_dir}/{reload_fle}')
Error was caused because we use pathlib for paths.
I think I fixed it. Will push it soon.
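To make the failure concrete, here is a small sketch of why the / operator fails on two strings, and two ways to build the path instead. The values assigned to output_dir and reload_fle are made up for illustration; only the variable names come from the diff.

```python
from pathlib import Path

# Illustrative values; in the pipeline these arrive as function arguments.
output_dir = "./events/run_1"
reload_fle = "save_model.txt"

# str / str is not defined, which is the TypeError shown in the traceback.
try:
    output_dir / reload_fle
    raised = False
except TypeError:
    raised = True

# Option 1: the f-string form suggested in the review.
path_from_fstring = Path(f"{output_dir}/{reload_fle}")

# Option 2: wrap the left operand in Path first; Path defines the / operator.
path_from_operator = Path(output_dir) / reload_fle
```

Both options produce the same Path object, so the choice is mostly a matter of taste; the second avoids hard-coding the separator.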
Further work has been done on this on my branch here
We have radically changed the repo since this point and deprecated the KGCN, so I'm closing this in favour of #161.
What is the goal of this PR?
Goal: To be able to save, reload, and run inference on trained KGCN models. This is valuable because, after training a useful KGCN model, you would like to be able to run it on other data(sets).
This PR adds functionality to save the trained model in TensorFlow, to reload the model after it has been saved, and to call an infer function on new data.
What are the changes implemented in this PR?
Three things are implemented: