Skip to content

Commit

Permalink
Sync OSS keras to head.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 370982115
  • Loading branch information
qlzh727 authored and tensorflower-gardener committed Apr 28, 2021
1 parent bd968bf commit 2bf1a9d
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
7 changes: 5 additions & 2 deletions keras/distribute/dataset_creator_model_fit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,11 @@ def testModelFitWithValidationData(self, strategy):

def testModelFitWithDatasetInstance(self, strategy):
with self.assertRaisesRegex(
NotImplementedError, "Only `DatasetCreator` input is supported in "
"`ParameterServerStrategy` at this time."):
NotImplementedError,
"Only `tf.keras.utils.experimental.DatasetCreator` input is supported "
"with `ParameterServerStrategy` at this time. Please see "
"`tf.keras.utils.experimental.DatasetCreator` class docstring for "
"more information."):
self._model_fit(
strategy, x=tf.data.Dataset.from_tensor_slices([1, 1]))

Expand Down
9 changes: 7 additions & 2 deletions keras/engine/data_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,8 +1319,13 @@ class _ClusterCoordinatorDataHandler(DataHandler):

def _verify_data_adapter_compatibility(self, adapter_cls):
if adapter_cls != DatasetCreatorAdapter:
raise NotImplementedError("Only `DatasetCreator` input is supported in "
"`ParameterServerStrategy` at this time.")
# TODO(b/186414920): Update the error message once `DatasetCreator` is no
# longer experimental.
raise NotImplementedError(
"Only `tf.keras.utils.experimental.DatasetCreator` input is "
"supported with `ParameterServerStrategy` at this time. Please see "
"`tf.keras.utils.experimental.DatasetCreator` class docstring for "
"more information.")

def _configure_dataset_and_inferred_steps(self, strategy, x, steps_per_epoch,
class_weight, distribute):
Expand Down
4 changes: 2 additions & 2 deletions keras/utils/dataset_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def dataset_fn(input_context):
input_options = tf.distribute.InputOptions(
experimental_fetch_to_device=True,
experimental_per_replica_buffer_size=2)
model.fit(DatasetCreator(dataset_fn, input_options=input_options),
epochs=10, steps_per_epoch=10)
model.fit(tf.keras.utils.experimental.DatasetCreator(
dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)
```
`Model.fit` usage with `DatasetCreator` is intended to work across all
Expand Down

0 comments on commit 2bf1a9d

Please sign in to comment.