Skip to content
This repository has been archived by the owner on Nov 18, 2023. It is now read-only.

Save model #124

Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion kglib/kgcn/learn/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,24 @@
from kglib.kgcn.learn.loss import loss_ops_preexisting_no_penalty
from kglib.kgcn.learn.metrics import existence_accuracy

from graph_nets import utils_np
from graph_nets.graphs import GraphsTuple


class KGCNLearner:
"""
Responsible for running a KGCN model
"""
def __init__(self, model, num_processing_steps_tr=10, num_processing_steps_ge=10):
def __init__(self, model, num_processing_steps_tr=10, num_processing_steps_ge=10, save_fle="save_model.txt", reload_fle=''):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at this, I think we instead want to have save_file as an argument to the method that does the training, and load_file can be a mandatory argument for the method that does the testing. That way, the user can make a single KGCNLearner and train it multiple times, saving to multiple different files.
Adding load_file to the test method gives us a guarantee for the test method that there is a file provided.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. I have adjusted it.

"""Args:
save_fle: Name to save the trained model to.
reload_fle: Name to load saved model from, when doing inference.
"""
self._model = model
self._num_processing_steps_tr = num_processing_steps_tr
self._num_processing_steps_ge = num_processing_steps_ge
self.save_fle = save_fle
self.reload_fle = reload_fle

def __call__(self,
tr_input_graphs,
Expand Down Expand Up @@ -102,6 +111,7 @@ def __call__(self,
train_writer = tf.summary.FileWriter(log_dir, sess.graph)

sess.run(tf.global_variables_initializer())
model_saver = tf.train.Saver()

logged_iterations = []
losses_tr = []
Expand Down Expand Up @@ -171,6 +181,51 @@ def __call__(self,
"outputs": output_ops_tr
},
feed_dict=feed_dict)

# Train the model and save it in the end
if not self.save_fle.is_dir():
model_saver.save(sess, self.save_fle.as_posix())
tf.train.write_graph(sess.graph.as_graph_def(), logdir=self.save_fle.parent.as_posix(), name=self.save_fle.with_suffix('.pbtxt').as_posix(), as_text=True)

training_info = logged_iterations, losses_tr, losses_ge, corrects_tr, corrects_ge, solveds_tr, solveds_ge
return train_values, test_values, training_info

# New function to infer / apply without training
# Inspired by: https://medium.com/@prasadpal107/saving-freezing-optimizing-for-inference-restoring-of-tensorflow-models-b4146deb21b5
def infer(self,
          input_graphs,
          target_graphs, log_dir):
    """Evaluate a previously trained and saved model without training it.

    Builds the model graph for the given graphs, restores the weights stored
    under `self.reload_fle`, runs a single forward pass and scores the output
    of the last message-passing step with `existence_accuracy`.

    Args:
        input_graphs: networkx graphs fed to the model.
        target_graphs: networkx graphs holding the expected outputs.
        log_dir: Unused here; kept so that existing callers keep working.

    Returns:
        A `(test_values, testing_info)` tuple. `test_values` is the dict
        returned by `sess.run`, with the keys "target" and "outputs".
        `testing_info` is a 7-tuple with the same layout as the
        `training_info` returned by training; only the `corrects_ge` and
        `solveds_ge` positions carry real values, the others are 0.

    Raises:
        FileNotFoundError: If `self.reload_fle` is empty or names a
            directory, i.e. there is no checkpoint prefix to restore from.
    """
    # Local import: the file-level import block is not guaranteed to
    # provide pathlib.
    from pathlib import Path

    # `reload_fle` may be a plain string (the constructor default is ''),
    # so normalise it before using any Path methods.
    checkpoint = Path(self.reload_fle)
    if checkpoint.is_dir():
        # Evaluating randomly initialised weights would silently report
        # meaningless accuracy, so fail loudly instead of carrying on.
        raise FileNotFoundError(
            "no file found, restoring failed: {!r} is not a checkpoint prefix".format(
                str(self.reload_fle)))

    input_ph, target_ph = create_placeholders(input_graphs, target_graphs)
    input_ph, target_ph = make_all_runnable_in_session(input_ph, target_ph)
    output_ops_ge = self._model(input_ph, self._num_processing_steps_ge)

    # Create the Saver over the variables the model has just built, so that
    # the restore fills exactly the variables `output_ops_ge` reads.
    # Importing the saved meta graph here would add a second copy of the
    # graph rather than load weights into this one.
    saver = tf.train.Saver()

    input_graphs_tuple = utils_np.networkxs_to_graphs_tuple(input_graphs)
    target_graphs_tuple = utils_np.networkxs_to_graphs_tuple(target_graphs)
    feed_dict = {
        input_ph: input_graphs_tuple,
        target_ph: target_graphs_tuple,
    }

    try:
        # The context manager closes the session even when restoring or
        # running fails. No variable initialiser is run: `restore` assigns
        # every variable covered by the Saver.
        with tf.Session() as sess:
            saver.restore(sess, checkpoint.as_posix())
            test_values = sess.run(
                {
                    "target": target_ph,
                    "outputs": output_ops_ge,
                },
                feed_dict=feed_dict)
    finally:
        # Resetting the default graph is only well defined once no session
        # is active, so do it after the session has been closed.
        tf.reset_default_graph()

    correct_ge, solved_ge = existence_accuracy(
        test_values["target"], test_values["outputs"][-1], use_edges=False)

    testing_info = 0, 0, 0, 0, [correct_ge], 0, [solved_ge]

    return test_values, testing_info
28 changes: 21 additions & 7 deletions kglib/kgcn/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import networkx as nx
import numpy as np
from pathlib import Path
from graph_nets.utils_np import graphs_tuple_to_networkxs

from kglib.kgcn.learn.learn import KGCNLearner
Expand All @@ -44,7 +45,10 @@ def pipeline(graphs,
attr_embedding_dim=6,
edge_output_size=3,
node_output_size=3,
output_dir=None):
output_dir=None,
do_test=False,
save_fle="test_model.ckpt",
reload_fle=""):

############################################################
# Manipulate the graph data
Expand Down Expand Up @@ -82,18 +86,28 @@ def pipeline(graphs,
node_output_size=node_output_size)

learner = KGCNLearner(kgcn,
num_processing_steps_tr=num_processing_steps_tr,
num_processing_steps_ge=num_processing_steps_ge)

train_values, test_values, tr_info = learner(tr_input_graphs,
num_processing_steps_tr=num_processing_steps_tr, # These processing steps indicate how many message-passing iterations to do for every training / testing step
num_processing_steps_ge=num_processing_steps_ge,
save_fle=f'{output_dir}/{save_fle}',
reload_fle=f'{output_dir}/{reload_fle}')

# only test
if not Path(output_dir / reload_fle).is_dir() and do_test == True:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line currently causes a test to fail with the following error:

======================================================================
ERROR: test_learning_is_done (__main__.TestDiagnosisExample)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/circleci/.cache/bazel/_bazel_circleci/046479231ae58926fd02b6cd39e690ed/execroot/kglib/bazel-out/k8-fastbuild/bin/tests/end_to_end/diagnosis.runfiles/kglib/tests/end_to_end/kgcn/diagnosis.py", line 39, in test_learning_is_done
    solveds_tr, solveds_ge = diagnosis_example()
  File "/home/circleci/.cache/bazel/_bazel_circleci/046479231ae58926fd02b6cd39e690ed/execroot/kglib/bazel-out/k8-fastbuild/bin/tests/end_to_end/diagnosis.runfiles/kglib/kglib/kgcn/examples/diagnosis/diagnosis.py", line 108, in diagnosis_example
    output_dir=f"./events/{time.time()}/")
  File "/home/circleci/.cache/bazel/_bazel_circleci/046479231ae58926fd02b6cd39e690ed/execroot/kglib/bazel-out/k8-fastbuild/bin/tests/end_to_end/diagnosis.runfiles/kglib/kglib/kgcn/pipeline/pipeline.py", line 95, in pipeline
    if not Path(output_dir / reload_fle).is_dir() and do_test == True:
TypeError: unsupported operand type(s) for /: 'str' and 'str'

----------------------------------------------------------------------
Ran 1 test in 255.269s

FAILED (errors=1)

You can see the test here. You can navigate to the tests via the tick or cross that's visible on the right of each commit on the PR. You can also run all of the tests locally on your own machine if you're up and running with bazel.

This line looks like it should be something like Path(f'{output_dir}/{reload_fle}')

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error occurred because we use pathlib for paths.
I think I have fixed it and will push the change soon.

test_values, tr_info = learner.test(ge_input_graphs,
ge_target_graphs,
log_dir=output_dir)
# train
else:
train_values, test_values, tr_info = learner(input_graphs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that you've added learner.test(...), it would be ideal to change this method to learner.train(...). I don't know whether this breaks anything, though; I can't remember if it will cause an issue with Sonnet, but it is worth a try. That means renaming KGCNLearner's __call__ magic method to train.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds logical.
I adjusted it. We will see what happens.

tr_target_graphs,
ge_input_graphs,
ge_target_graphs,
num_training_iterations=num_training_iterations,
log_dir=output_dir)

plot_across_training(*tr_info, output_file=f'{output_dir}learning.png')
plot_predictions(graphs[tr_ge_split:], test_values, num_processing_steps_ge, output_file=f'{output_dir}graph.png')

plot_across_training(*tr_info, output_file=f'{output_dir}/learning.png')
plot_predictions(graphs[tr_ge_split:], test_values, num_processing_steps_ge, output_file=f'{output_dir}/graph.png')

logit_graphs = graphs_tuple_to_networkxs(test_values["outputs"][-1])

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@ tensorflow-probability==0.7.0
termcolor==1.1.0
Werkzeug==0.15.6
wrapt==1.11.2