Commit
Additional tests on delayed update (#281)
guillaumekln authored Nov 30, 2018
1 parent ff38e89 commit a79e972
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions opennmt/tests/optim_test.py
@@ -56,6 +56,39 @@ def _check_step(grad, expected_variable, expected_step):
    _check_step([0.0, -3.0], [-5.0, -2.0], 1)  # accum_grad = [-3.0, -2.0]
    _check_step([2.0, -1.0], [-4.0, 1.0], 2)   # accum_grad = [-1.0, -3.0], apply

  def testDelayedUpdateSparseGradients(self):
    # Test that delayed update does not crash on sparse gradients.
    global_step = tf.Variable(0, trainable=False, dtype=tf.int64)
    optimizer = tf.train.GradientDescentOptimizer(1.0)
    embeddings = tf.Variable([[1.0, 2.0], [3.0, 4.0]])
    x = tf.nn.embedding_lookup(embeddings, [0])
    loss = tf.losses.mean_squared_error([[1.1, 2.1]], x)
    gradients = optimizer.compute_gradients(loss)
    _ = optim.delayed_update(
        optimizer,
        gradients,
        global_step,
        accum_count=3)

  def testDelayedUpdateOptimizerSlots(self):
    # Test that delayed update does not change any variable names, in particular
    # optimizer variables.
    def _create_variables(accum_count):
      global_step = tf.Variable(0, trainable=False, dtype=tf.int64)
      optimizer = tf.train.AdamOptimizer(1.0)
      gradient = tf.placeholder(tf.float32, shape=[2])
      variable = tf.Variable([1.0, 2.0])
      optim.delayed_update(
          optimizer,
          [(gradient, variable)],
          global_step,
          accum_count=accum_count)
      return list(sorted(var.name for var in tf.global_variables()))

    vars_no_accum = _create_variables(accum_count=1)
    tf.reset_default_graph()
    vars_accum = _create_variables(accum_count=3)
    self.assertListEqual(vars_accum, vars_no_accum)

if __name__ == "__main__":
  tf.test.main()
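
For context, the behaviour these tests exercise is gradient accumulation: optim.delayed_update sums gradients over accum_count steps and applies the optimizer update only on every accum_count-th step, so the variable is only expected to move on the step where the accumulated gradient is applied (marked "apply" in the _check_step comments above). Below is a minimal pure-Python/NumPy sketch of that accumulate-then-apply idea; delayed_sgd is a hypothetical helper written for illustration only, not the optim.delayed_update implementation.

import numpy as np

def delayed_sgd(variable, grads_per_step, learning_rate=1.0, accum_count=3):
  """Sketch: accumulate gradients and apply a plain SGD update every accum_count steps."""
  variable = np.asarray(variable, dtype=np.float64)
  accum_grad = np.zeros_like(variable)
  for step, grad in enumerate(grads_per_step, start=1):
    accum_grad += np.asarray(grad, dtype=np.float64)  # accumulate this step's gradient
    if step % accum_count == 0:                       # apply only on every accum_count-th step
      variable = variable - learning_rate * accum_grad
      accum_grad = np.zeros_like(variable)            # reset the accumulator
  return variable

# Example: with accum_count=2 the two gradients are summed to [1.0, 1.0]
# and applied in a single SGD step, giving [0.0, 1.0].
print(delayed_sgd([1.0, 2.0], [[0.5, 0.5], [0.5, 0.5]], accum_count=2))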
