Skip to content

Commit

Permalink
test case +black
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-g committed Dec 19, 2020
1 parent 03443f5 commit eb7b1dc
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
2 changes: 1 addition & 1 deletion elegy/losses/sparse_categorical_crossentropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def sparse_categorical_crossentropy(
) -> jnp.ndarray:

n_classes = y_pred.shape[-1]
#fix for a bug in jax<0.2.7 where take_along_axis returns wrong values
# fix for a bug in jax<0.2.7 where take_along_axis returns wrong values
if y_true.dtype in [jnp.int8, jnp.uint8, jnp.int16, jnp.uint16]:
y_true = y_true.astype(jnp.int32)

Expand Down
14 changes: 13 additions & 1 deletion elegy/losses/sparse_categorical_crossentropy_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import elegy


import jax.numpy as jnp
import jax, jax.numpy as jnp
import numpy as np
import tensorflow.keras as keras


#
Expand Down Expand Up @@ -47,3 +49,13 @@ def test_scce_out_of_bounds():
scce = elegy.losses.SparseCategoricalCrossentropy(check_bounds=False)
assert not jnp.isnan(scce(ytrue0, ypred)).any()
assert not jnp.isnan(scce(ytrue1, ypred)).any()


def test_scce_uint8_ytrue():
ypred = np.random.random([2, 256, 256, 10])
ytrue = np.random.randint(0, 10, size=(2, 256, 256)).astype(np.uint8)

loss0 = elegy.losses.sparse_categorical_crossentropy(ytrue, ypred, from_logits=True)
loss1 = keras.losses.sparse_categorical_crossentropy(ytrue, ypred, from_logits=True)

assert np.allclose(loss0, loss1)

0 comments on commit eb7b1dc

Please sign in to comment.