Fix not fully defined shape of gradient accumulators on TF <= 1.9
guillaumekln committed Nov 30, 2018
1 parent f6cd590 commit f6641f1
Showing 2 changed files with 5 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -20,6 +20,7 @@ OpenNMT-tf follows [semantic versioning 2.0.0](https://semver.org/). The API cov

 * Checkpoint utilities now save a relative path instead of absolute in the generated checkpoint state
 * Fix error on missing configuration fields that should be optional
+* Fix error on gradient accumulation in TensorFlow versions <= 1.9

 ## [1.14.1](https://github.com/OpenNMT/OpenNMT-tf/releases/tag/v1.14.1) (2018-11-28)

5 changes: 4 additions & 1 deletion opennmt/utils/optim.py
@@ -209,7 +209,10 @@ def delayed_update(optimizer, grads_and_vars, global_step, accum_count=1):
   accum_grads = []
   accum_grads_and_vars = []
   for grad, var in grads_and_vars:
-    accum_grad = tf.Variable(tf.zeros_like(grad), trainable=False, collections=[])
+    accum_grad = tf.Variable(
+        tf.zeros(var.shape, dtype=grad.dtype),
+        trainable=False,
+        collections=[])
     accum_grads.append(accum_grad)
     accum_grads_and_vars.append((accum_grad, var))
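
The idea behind this change can be sketched without TensorFlow. What follows is a conceptual illustration only, not the OpenNMT-tf implementation: `zeros` and `DelayedUpdate` are hypothetical plain-Python stand-ins, the SGD step and the summing of gradients are assumptions, and a shape is modeled as a tuple in which `None` marks an unknown dimension. As the commit title suggests, a buffer sized from the gradient can end up with a shape that is not fully defined, while a buffer sized from the variable avoids this because a variable's shape is normally fully known.

```python
# Conceptual sketch only: plain Python stand-ins, not TensorFlow or OpenNMT-tf code.

def zeros(shape):
    """Allocate a flat zero buffer; every dimension of `shape` must be known."""
    if any(dim is None for dim in shape):
        raise ValueError("shape %r is not fully defined" % (shape,))
    size = 1
    for dim in shape:
        size *= dim
    return [0.0] * size


class DelayedUpdate(object):
    """Sums gradients and applies them once every `accum_count` steps."""

    def __init__(self, var_shapes, accum_count, learning_rate=0.1):
        self.accum_count = accum_count
        self.learning_rate = learning_rate
        # Buffers are sized from the variable shapes, not from the gradients.
        self.variables = [zeros(shape) for shape in var_shapes]
        self.accum_grads = [zeros(shape) for shape in var_shapes]
        self.step = 0

    def update(self, grads):
        """Add `grads` to the buffers; return True if variables were updated."""
        for accum, grad in zip(self.accum_grads, grads):
            for i, value in enumerate(grad):
                accum[i] += value
        self.step += 1
        if self.step % self.accum_count != 0:
            return False
        # Apply the summed gradients with plain SGD (an assumption of this
        # sketch), then reset the buffers for the next accumulation window.
        for var, accum in zip(self.variables, self.accum_grads):
            for i in range(len(var)):
                var[i] -= self.learning_rate * accum[i]
                accum[i] = 0.0
        return True
```

In this sketch, `zeros((None, 3))` raises a `ValueError`, which mirrors the kind of failure the commit avoids by building the accumulator from `var.shape` and `grad.dtype` instead of from `grad` itself.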

