Skip to content
This repository has been archived by the owner on Nov 18, 2023. It is now read-only.

Commit

Permalink
Refactor Classifiers (#64)
Browse files Browse the repository at this point in the history
We have removed repeated code among the supervised classifiers
  • Loading branch information
jmsfltchr authored May 29, 2019
1 parent 7a8b0e2 commit 352aef1
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 209 deletions.
2 changes: 1 addition & 1 deletion examples/kgcn/animal_trade/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ The [main](../../../examples/kgcn/animal_trade/main.py) function will:

- Search Grakn for the k-hop neighbours of the selected examples, and store information about them as arrays, denoted in the code as `context_arrays`. These data are saved to file so that subsequent steps can be re-run without recomputing them.

- Build the TensorFlow computation graph using `model.KGCN`, including a multi-class classification step and learning procedure defined by `classify.SupervisedKGCNClassifier`
- Build the TensorFlow computation graph using `model.KGCN`, including a multi-class classification step and learning procedure defined by `classify.SupervisedKGCNMultiClassSingleLabelClassifier`

- Feed the `context_arrays` to the TensorFlow graph, and perform learning

Expand Down
4 changes: 2 additions & 2 deletions examples/kgcn/animal_trade/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def main(modes=(TRAIN, EVAL, PREDICT)):
neighbour_sampling_limit_factor=4)

optimizer = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
classifier = classify.SupervisedKGCNClassifier(kgcn, optimizer, FLAGS.num_classes, FLAGS.log_dir,
max_training_steps=FLAGS.max_training_steps)
classifier = classify.SupervisedKGCNMultiClassSingleLabelClassifier(kgcn, optimizer, FLAGS.num_classes, FLAGS.log_dir,
max_training_steps=FLAGS.max_training_steps)

feed_dicts = {}
feed_dict_storer = persistence.FeedDictStorer(BASE_PATH + 'input/')
Expand Down
69 changes: 66 additions & 3 deletions examples/kgcn/animal_trade/test/end_to_end_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,76 @@

class TestEndToEnd(unittest.TestCase):

def test_end_to_end(self):
@classmethod
def setUpClass(cls):
    """Unpack the bundled Grakn distribution and start its server.

    Declared as a class-level fixture, so the unzip and server start are
    performed before the test methods of this class rather than per test.
    """
    archive_path = 'external/animaltrade_dist/file/downloaded'
    unpack_dir = 'external/animaltrade_dist/file/downloaded-unzipped'
    grakn_launcher = 'external/animaltrade_dist/file/downloaded-unzipped/grakn-core-all-mac-animaltrade1.5.3/grakn'

    # The archive is the Grakn distribution that ships with our data
    # NOTE(review): the result of each command is not inspected here, so a failed
    # unzip or server start would only surface later in the tests - confirm this is intended
    sub.run(['unzip', archive_path, '-d', unpack_dir])

    # Launch the Grakn server from the freshly unpacked distribution
    sub.run([grakn_launcher, 'server', 'start'])

def test_multi_class_single_label_classification_end_to_end(self):
    """Run the single-label classifier through feed-dict creation, training and evaluation.

    Builds a KGCN model and a SupervisedKGCNMultiClassSingleLabelClassifier, compiles
    labelled example concepts, creates a feed dict for each mode in `modes`, then trains
    and evaluates. This test makes no assertions on the results; it only checks that the
    pipeline runs without raising.

    The sessions and transactions opened at the start are closed in a `finally` clause,
    so they are released even if any step of the pipeline raises.
    """
    tf.reset_default_graph()

    modes = (TRAIN, EVAL)

    client = grakn.client.GraknClient(uri=URI)
    sessions = server_mgmt.get_sessions(client, KEYSPACES)
    transactions = server_mgmt.get_transactions(sessions)

    try:
        batch_size = NUM_PER_CLASS * FLAGS.num_classes
        kgcn = model.KGCN(NEIGHBOUR_SAMPLE_SIZES,
                          FLAGS.features_size,
                          FLAGS.starting_concepts_features_size,
                          FLAGS.aggregated_size,
                          FLAGS.embedding_size,
                          transactions[TRAIN],
                          batch_size,
                          neighbour_sampling_method=random_sampling.random_sample,
                          neighbour_sampling_limit_factor=4)

        optimizer = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
        # The log directory argument is None for this test
        classifier = classify.SupervisedKGCNMultiClassSingleLabelClassifier(
            kgcn, optimizer, FLAGS.num_classes, None,
            max_training_steps=FLAGS.max_training_steps)

        feed_dicts = {}

        sampling_params = {
            TRAIN: {'sample_size': NUM_PER_CLASS, 'population_size': POPULATION_SIZE_PER_CLASS},
            EVAL: {'sample_size': NUM_PER_CLASS, 'population_size': POPULATION_SIZE_PER_CLASS},
            PREDICT: {'sample_size': NUM_PER_CLASS, 'population_size': POPULATION_SIZE_PER_CLASS},
        }
        concepts, labels = thing_mgmt.compile_labelled_concepts(EXAMPLES_QUERY, EXAMPLE_CONCEPT_TYPE,
                                                                LABEL_ATTRIBUTE_TYPE, ATTRIBUTE_VALUES,
                                                                transactions[TRAIN], transactions[PREDICT],
                                                                sampling_params)

        for mode in modes:
            mode_labels = labels[mode]
            feed_dicts[mode] = classifier.get_feed_dict(sessions[mode], concepts[mode], labels=mode_labels)

        # Note: The ground-truth attribute labels haven't been removed from Grakn, so the results found here are
        # invalid, and used as an end-to-end test only

        # Train
        if TRAIN in modes:
            print("\n\n********** TRAIN Keyspace **********")
            classifier.train(feed_dicts[TRAIN])

        # Eval
        if EVAL in modes:
            print("\n\n********** EVAL Keyspace **********")
            # Presently, eval keyspace is the same as the TRAIN keyspace
            classifier.eval(feed_dicts[EVAL])
    finally:
        # Always release the Grakn resources, including when a step above fails,
        # so a failing run does not leave sessions or transactions open
        server_mgmt.close(sessions)
        server_mgmt.close(transactions)

def test_multi_class_multi_label_classification_end_to_end(self):
tf.reset_default_graph()

modes = (TRAIN, EVAL)

client = grakn.client.GraknClient(uri=URI)
Expand All @@ -102,7 +164,7 @@ def test_end_to_end(self):
neighbour_sampling_limit_factor=4)

optimizer = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
classifier = classify.SupervisedKGCNClassifier(kgcn, optimizer, FLAGS.num_classes, None,
classifier = classify.SupervisedKGCNMultiClassMultiLabelClassifier(kgcn, optimizer, FLAGS.num_classes, None,
max_training_steps=FLAGS.max_training_steps)

feed_dicts = {}
Expand All @@ -118,7 +180,8 @@ def test_end_to_end(self):
sampling_params)

for mode in modes:
feed_dicts[mode] = classifier.get_feed_dict(sessions[mode], concepts[mode], labels=labels[mode])
mode_labels = [[0, 1, 1]] * len(labels[mode])
feed_dicts[mode] = classifier.get_feed_dict(sessions[mode], concepts[mode], labels=mode_labels)

# Note: The ground-truth attribute labels haven't been removed from Grakn, so the results found here are
# invalid, and used as an end-to-end test only
Expand Down
2 changes: 1 addition & 1 deletion kglib/kgcn/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ kgcn = model.KGCN(neighbour_sample_sizes,

optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)

classifier = classify.SupervisedKGCNClassifier(kgcn,
classifier = classify.SupervisedKGCNMultiClassSingleLabelClassifier(kgcn,
optimizer,
num_classes,
log_dir,
Expand Down
Loading

0 comments on commit 352aef1

Please sign in to comment.