Skip to content

Commit

Permalink
Fix compatibility of random.beta in tf.data (#19923)
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 authored Jun 26, 2024
1 parent 656df40 commit a2e9a52
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
6 changes: 2 additions & 4 deletions keras/src/backend/tensorflow/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,8 @@ def beta(shape, alpha, beta, dtype=None, seed=None):
# such as for output shape of (2, 3) and alpha shape of (1, 3)
# So to resolve this, we explicitly broadcast alpha and beta to shape before
# passing them to the stateless_gamma function.
if tf.rank(alpha) > 1:
alpha = tf.broadcast_to(alpha, shape)
if tf.rank(beta) > 1:
beta = tf.broadcast_to(beta, shape)
alpha = tf.broadcast_to(alpha, shape)
beta = tf.broadcast_to(beta, shape)

gamma_a = tf.cast(
tf.random.stateless_gamma(
Expand Down
28 changes: 28 additions & 0 deletions keras/src/random/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,34 @@ def test_beta(self, seed, shape, alpha, beta, dtype):
)


class RandomBehaviorTest(testing.TestCase, parameterized.TestCase):
def test_beta_tf_data_compatibility(self):
import tensorflow as tf

from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer

class BetaLayer(TFDataLayer):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def compute_output_shape(self, input_shape):
return input_shape

def call(self, inputs):
noise = self.backend.random.beta(
self.backend.shape(inputs), alpha=0.5, beta=0.5
)
inputs = inputs + noise
return inputs

layer = BetaLayer()
input_data = np.random.random([2, 4, 4, 3])
ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output = output.numpy()
self.assertEqual(tuple(output.shape), (2, 4, 4, 3))


class RandomDTypeTest(testing.TestCase, parameterized.TestCase):
INT_DTYPES = [x for x in dtypes.INT_TYPES if x != "uint64"]
FLOAT_DTYPES = dtypes.FLOAT_TYPES
Expand Down

0 comments on commit a2e9a52

Please sign in to comment.