Skip to content

Commit

Permalink
Do not return datasets from input_fn for TensorFlow 1.4 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumekln authored Mar 11, 2019
1 parent dbf3b26 commit 6a34f95
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ OpenNMT-tf follows [semantic versioning 2.0.0](https://semver.org/). The API cov

### Fixes and improvements

* Fix compatibility issue with legacy TensorFlow 1.4

## [1.21.4](https://github.com/OpenNMT/OpenNMT-tf/releases/tag/v1.21.4) (2019-03-07)

### Fixes and improvements
Expand Down
9 changes: 6 additions & 3 deletions opennmt/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ def _build_train_spec(self, checkpoint_path):
num_shards=self._hvd.size() if self._hvd is not None else 1,
shard_index=self._hvd.rank() if self._hvd is not None else 0,
num_threads=self._config["train"].get("num_threads"),
prefetch_buffer_size=self._config["train"].get("prefetch_buffer_size")),
prefetch_buffer_size=self._config["train"].get("prefetch_buffer_size"),
return_dataset=False),
max_steps=train_steps,
hooks=train_hooks)
return train_spec
Expand All @@ -261,7 +262,8 @@ def _build_eval_spec(self):
features_file=self._config["data"]["eval_features_file"],
labels_file=self._config["data"].get("eval_labels_file"),
num_threads=self._config["eval"].get("num_threads"),
prefetch_buffer_size=self._config["eval"].get("prefetch_buffer_size")),
prefetch_buffer_size=self._config["eval"].get("prefetch_buffer_size"),
return_dataset=False),
steps=None,
exporters=_make_exporters(
self._config["eval"]["exporters"],
Expand Down Expand Up @@ -388,7 +390,8 @@ def infer(self,
features_file=features_file,
bucket_width=self._config["infer"]["bucket_width"],
num_threads=self._config["infer"].get("num_threads"),
prefetch_buffer_size=self._config["infer"].get("prefetch_buffer_size"))
prefetch_buffer_size=self._config["infer"].get("prefetch_buffer_size"),
return_dataset=False)

if predictions_file:
stream = io.open(predictions_file, encoding="utf-8", mode="w")
Expand Down

0 comments on commit 6a34f95

Please sign in to comment.