Fixed torch and jax tests
gowthamkpr committed Nov 13, 2024
1 parent 83b66ed commit e4a334d
Showing 3 changed files with 12 additions and 9 deletions.
@@ -17,7 +17,9 @@
 class DifferentialBinarizationOCRTest(TestCase):
     def setUp(self):
         self.images = ops.ones((2, 32, 32, 3))
-        self.labels = ops.zeros((2, 32, 32, 4))
+        self.labels = ops.concatenate(
+            (ops.zeros((2, 16, 32, 4)), ops.ones((2, 16, 32, 4))), axis=1
+        )
         image_encoder = ResNetBackbone(
             input_conv_filters=[4],
             input_conv_kernel_sizes=[7],
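
For reference, a minimal NumPy sketch (not part of the commit, identifiers illustrative) of what the new setUp labels look like: concatenating a zeros block and a ones block along the height axis yields a (2, 32, 32, 4) tensor whose top half is background and whose bottom half is foreground, so the synthetic batch now contains both classes instead of only zeros.

import numpy as np

# top 16 rows of each 32x32 label map are zeros, bottom 16 rows are ones
labels = np.concatenate(
    (np.zeros((2, 16, 32, 4)), np.ones((2, 16, 32, 4))), axis=1
)
assert labels.shape == (2, 32, 32, 4)
assert labels[:, :16].max() == 0.0
assert labels[:, 16:].min() == 1.0
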
keras_hub/src/models/differential_binarization/losses.py (7 changes: 4 additions & 3 deletions)
@@ -60,14 +60,15 @@ def __call__(self, y_true, y_pred, mask, return_origin=False):
         positive_loss = loss * ops.cast(positive, "float32")
         negative_loss = loss * ops.cast(negative, "float32")

-        # hard negative mining, as suggested in the paper:
-        # compute the threshold for hard negatives, and zero-out
+        # hard negative mining, as suggested in
+        # [Real-time Scene Text Detection with Differentiable Binarization](https://arxiv.org/abs/1911.08947):
+        # Compute the threshold for hard negatives, and zero-out
         # negative losses below the threshold. using this approach,
         # we achieve efficient computation on GPUs

         # compute negative_count relative to the element count of y_pred
         negative_count_rel = ops.cast(negative_count, "float32") / ops.prod(
-            ops.shape(negative_count)
+            ops.cast(ops.shape(y_pred), "float32")
         )
         # compute the threshold value for negative losses and zero neg. loss
         # values below this threshold
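
As an aside (not part of the commit), the comment in this hunk describes threshold-based hard negative mining; a rough NumPy sketch under the common assumption of a 3:1 negative-to-positive ratio is given below. The helper name and details are hypothetical, not the keras_hub implementation. It also mirrors the fix above: the ratio of kept negatives has to be taken against the element count of the full loss map (y_pred), not against the shape of the scalar negative_count.

import numpy as np

def hard_negative_mine(loss, positive_mask, neg_pos_ratio=3.0):
    # Hypothetical helper, only an illustration of the technique.
    positive = positive_mask.astype(bool)
    negative = ~positive
    positive_loss = loss * positive
    negative_loss = loss * negative

    # keep at most neg_pos_ratio hard negatives per positive
    negative_count = min(
        int(neg_pos_ratio * positive.sum()), int(negative.sum())
    )
    # fraction of all elements that the kept negatives represent
    negative_count_rel = negative_count / loss.size
    # approximate loss value of the k-th hardest negative via a quantile,
    # then zero out every negative loss below that threshold; this avoids
    # an explicit sort-and-gather and vectorizes cheaply on GPU
    threshold = np.quantile(negative_loss, 1.0 - negative_count_rel)
    hard_negative_loss = np.where(negative_loss >= threshold, negative_loss, 0.0)

    # balance the sum over positives and the mined negatives
    return (positive_loss.sum() + hard_negative_loss.sum()) / (
        positive.sum() + negative_count + 1e-6
    )
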
keras_hub/src/models/differential_binarization/losses_test.py (10 changes: 5 additions & 5 deletions)
@@ -16,14 +16,14 @@ def test_loss(self):
         mask = np.array([0.0, 1.0, 1.0, 0.0])
         weights = np.array([4.0, 5.0, 6.0, 7.0])
         loss = self.loss_obj(y_true, y_pred, mask, weights)
-        self.assertAlmostEqual(loss.numpy(), 0.74358, delta=1e-4)
+        self.assertAlmostEqual(loss, 0.74358, delta=1e-4)

     def test_correct(self):
         y_true = np.array([1.0, 1.0, 0.0, 0.0])
         y_pred = y_true
         mask = np.array([0.0, 1.0, 1.0, 0.0])
         loss = self.loss_obj(y_true, y_pred, mask)
-        self.assertAlmostEqual(loss.numpy(), 0.0, delta=1e-4)
+        self.assertAlmostEqual(loss, 0.0, delta=1e-4)


 class MaskL1LossTest(TestCase):
@@ -35,7 +35,7 @@ def test_masked(self):
         y_pred = np.array([0.1, 0.2, 0.3, 0.4])
         mask = np.array([0.0, 1.0, 0.0, 1.0])
         loss = self.loss_obj(y_true, y_pred, mask)
-        self.assertAlmostEqual(loss.numpy(), 2.7, delta=1e-4)
+        self.assertAlmostEqual(loss, 2.7, delta=1e-4)


 class DBLossTest(TestCase):
@@ -55,7 +55,7 @@ def test_loss(self):
         )
         y_pred = np.stack((p_map_pred, t_map_pred, b_map_pred), axis=-1)
         loss = self.loss_obj(y_true, y_pred)
-        self.assertAlmostEqual(loss.numpy(), 14.1123, delta=1e-4)
+        self.assertAlmostEqual(loss, 14.1123, delta=1e-4)

     def test_correct(self):
         shrink_map = thresh_map = np.array(
@@ -68,4 +68,4 @@ def test_correct(self):
         )
         y_pred = np.stack((p_map_pred, t_map_pred, b_map_pred), axis=-1)
         loss = self.loss_obj(y_true, y_pred)
-        self.assertAlmostEqual(loss.numpy(), 0.0, delta=1e-4)
+        self.assertAlmostEqual(loss, 0.0, delta=1e-4)
