diff --git a/keras/src/backend/tensorflow/random.py b/keras/src/backend/tensorflow/random.py
index 0212610085d..4b6f12c4559 100644
--- a/keras/src/backend/tensorflow/random.py
+++ b/keras/src/backend/tensorflow/random.py
@@ -167,10 +167,8 @@ def beta(shape, alpha, beta, dtype=None, seed=None):
     # such as for output shape of (2, 3) and alpha shape of (1, 3)
     # So to resolve this, we explicitly broadcast alpha and beta to shape before
     # passing them to the stateless_gamma function.
-    if tf.rank(alpha) > 1:
-        alpha = tf.broadcast_to(alpha, shape)
-    if tf.rank(beta) > 1:
-        beta = tf.broadcast_to(beta, shape)
+    alpha = tf.broadcast_to(alpha, shape)
+    beta = tf.broadcast_to(beta, shape)
 
     gamma_a = tf.cast(
         tf.random.stateless_gamma(
diff --git a/keras/src/random/random_test.py b/keras/src/random/random_test.py
index a7358edbc25..3bbf6b29aed 100644
--- a/keras/src/random/random_test.py
+++ b/keras/src/random/random_test.py
@@ -391,6 +391,34 @@ def test_beta(self, seed, shape, alpha, beta, dtype):
         )
 
 
+class RandomBehaviorTest(testing.TestCase, parameterized.TestCase):
+    def test_beta_tf_data_compatibility(self):
+        import tensorflow as tf
+
+        from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
+
+        class BetaLayer(TFDataLayer):
+            def __init__(self, **kwargs):
+                super().__init__(**kwargs)
+
+            def compute_output_shape(self, input_shape):
+                return input_shape
+
+            def call(self, inputs):
+                noise = self.backend.random.beta(
+                    self.backend.shape(inputs), alpha=0.5, beta=0.5
+                )
+                inputs = inputs + noise
+                return inputs
+
+        layer = BetaLayer()
+        input_data = np.random.random([2, 4, 4, 3])
+        ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
+        for output in ds.take(1):
+            output = output.numpy()
+        self.assertEqual(tuple(output.shape), (2, 4, 4, 3))
+
+
 class RandomDTypeTest(testing.TestCase, parameterized.TestCase):
     INT_DTYPES = [x for x in dtypes.INT_TYPES if x != "uint64"]
     FLOAT_DTYPES = dtypes.FLOAT_TYPES