Skip to content

Commit

Permalink
Support TensorFlow metric in TFOptimizer (intel#1593)
Browse files Browse the repository at this point in the history
* test

* support stateless metric

* support metric in TFOptimizer

* fix tests

* fix style

* add tests
  • Loading branch information
yangw1234 authored Sep 2, 2019
1 parent 7880884 commit 95563ee
Showing 1 changed file with 0 additions and 27 deletions.
27 changes: 0 additions & 27 deletions test_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,33 +84,6 @@ def test_init_tfnet_from_session(self):
self.assert_allclose(output_value, output_value_ref)
self.assert_allclose(grad_input_value, grad_input_value_ref)

def test_tf_optimizer_with_sparse_gradient(self):
import tensorflow as tf

ids = np.random.randint(0, 10, size=[40])
labels = np.random.randint(0, 5, size=[40])
id_rdd = self.sc.parallelize(ids)
label_rdd = self.sc.parallelize(labels)
training_rdd = id_rdd.zip(label_rdd).map(lambda x: [x[0], x[1]])
with tf.Graph().as_default():
dataset = TFDataset.from_rdd(training_rdd,
names=["ids", "labels"],
shapes=[[], []],
types=[tf.int32, tf.int32],
batch_size=8)
id_tensor, label_tensor = dataset.tensors
embedding_table = tf.get_variable(
name="word_embedding",
shape=[10, 5])

embedding = tf.nn.embedding_lookup(embedding_table, id_tensor)
loss = tf.reduce_mean(tf.losses.
sparse_softmax_cross_entropy(logits=embedding,
labels=label_tensor))
optimizer = TFOptimizer(loss, Adam(1e-3))
optimizer.optimize(end_trigger=MaxEpoch(1))
optimizer.sess.close()

def test_tf_net_predict(self):
resource_path = os.path.join(os.path.split(__file__)[0], "../../resources")
tfnet_path = os.path.join(resource_path, "tfnet")
Expand Down

0 comments on commit 95563ee

Please sign in to comment.