Skip to content

Commit

Permalink
Support dict inputs in TFDataLayer, plus some lint fixes (#383)
Browse files Browse the repository at this point in the history
  • Loading branch information
ianstenbit authored and fchollet committed Jun 21, 2023
1 parent 2350e68 commit d955292
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
4 changes: 1 addition & 3 deletions keras_core/backend/tensorflow/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):

def categorical(logits, num_samples, dtype="int64", seed=None):
seed = tf_draw_seed(seed)
output = tf.random.stateless_categorical(
logits, num_samples, seed=seed
)
output = tf.random.stateless_categorical(logits, num_samples, seed=seed)
return tf.cast(output, dtype)


Expand Down
5 changes: 4 additions & 1 deletion keras_core/backend/torch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ def categorical(logits, num_samples, dtype="int32", seed=None):
dtype = to_torch_dtype(dtype)
generator = torch_seed_generator(seed, device=get_device())
return torch.multinomial(
logits, num_samples, replacement=True, generator=generator,
logits,
num_samples,
replacement=True,
generator=generator,
).type(dtype)


Expand Down
9 changes: 7 additions & 2 deletions keras_core/layers/preprocessing/tf_data_layer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from tensorflow import nest

from keras_core import backend
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
Expand All @@ -22,8 +24,11 @@ def __call__(self, inputs, **kwargs):
):
# We're in a TF graph, e.g. a tf.data pipeline.
self.backend.set_backend("tensorflow")
inputs = self.backend.convert_to_tensor(
inputs, dtype=self.compute_dtype
inputs = nest.map_structure(
lambda x: self.backend.convert_to_tensor(
x, dtype=self.compute_dtype
),
inputs,
)
switch_convert_input_args = False
if self._convert_input_args:
Expand Down

0 comments on commit d955292

Please sign in to comment.