-
Notifications
You must be signed in to change notification settings - Fork 93
Save model #124
Save model #124
Changes from 6 commits
d1ac7dd
7720b75
f50aa66
5b45164
1c96330
e50c55d
b58a4fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
|
||
import networkx as nx | ||
import numpy as np | ||
from pathlib import Path | ||
from graph_nets.utils_np import graphs_tuple_to_networkxs | ||
|
||
from kglib.kgcn.learn.learn import KGCNLearner | ||
|
@@ -44,7 +45,10 @@ def pipeline(graphs, | |
attr_embedding_dim=6, | ||
edge_output_size=3, | ||
node_output_size=3, | ||
output_dir=None): | ||
output_dir=None, | ||
do_test=False, | ||
save_fle="test_model.ckpt", | ||
reload_fle=""): | ||
|
||
############################################################ | ||
# Manipulate the graph data | ||
|
@@ -82,18 +86,28 @@ def pipeline(graphs, | |
node_output_size=node_output_size) | ||
|
||
learner = KGCNLearner(kgcn, | ||
num_processing_steps_tr=num_processing_steps_tr, | ||
num_processing_steps_ge=num_processing_steps_ge) | ||
|
||
train_values, test_values, tr_info = learner(tr_input_graphs, | ||
num_processing_steps_tr=num_processing_steps_tr, # These processing steps indicate how many message-passing iterations to do for every training / testing step | ||
num_processing_steps_ge=num_processing_steps_ge, | ||
save_fle=f'{output_dir}/{save_fle}', | ||
reload_fle=f'{output_dir}/{reload_fle}') | ||
|
||
# only test | ||
if not Path(output_dir / reload_fle).is_dir() and do_test == True: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line currently causes a test to fail with the following error:
You can see the test here. You can navigate to the tests via the tick or cross that's visible on the right of each commit on the PR. You can also run all of the tests locally on your own machine if you're up and running with bazel. This line looks like it should be something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Error was caused because we use pathlib for paths. |
||
test_values, tr_info = learner.test(ge_input_graphs, | ||
ge_target_graphs, | ||
log_dir=output_dir) | ||
# train | ||
else: | ||
train_values, test_values, tr_info = learner(input_graphs, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now that you've added There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds logical. |
||
tr_target_graphs, | ||
ge_input_graphs, | ||
ge_target_graphs, | ||
num_training_iterations=num_training_iterations, | ||
log_dir=output_dir) | ||
|
||
plot_across_training(*tr_info, output_file=f'{output_dir}learning.png') | ||
plot_predictions(graphs[tr_ge_split:], test_values, num_processing_steps_ge, output_file=f'{output_dir}graph.png') | ||
|
||
plot_across_training(*tr_info, output_file=f'{output_dir}/learning.png') | ||
plot_predictions(graphs[tr_ge_split:], test_values, num_processing_steps_ge, output_file=f'{output_dir}/graph.png') | ||
|
||
logit_graphs = graphs_tuple_to_networkxs(test_values["outputs"][-1]) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,3 +33,4 @@ tensorflow-probability==0.7.0 | |
termcolor==1.1.0 | ||
Werkzeug==0.15.6 | ||
wrapt==1.11.2 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at this, I think we instead want to have
save_file
as an argument to the method that does the training, andload_file
can be a mandatory argument for the method that does the testing. That way, the user can make a single KGCNLearner and train it multiple times, saving to multiple different files.Adding
load_file
to the test method gives us a guarantee for the test method that there is a file provided.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree. Adjusted it