Skip to content

Commit

Permalink
Bug Fix for CenterNetBoxLoss (#2432)
Browse files Browse the repository at this point in the history
* fix CenterNetBoxLoss, add Testcase

* cleanup testcase

* test case code clean up

* fix Pytorch pipeline

* make test_heading_regression_loss framework agnostic

---------

Co-authored-by: Till Beemelmanns <[email protected]>
  • Loading branch information
TillBeemelmanns and Till Beemelmanns authored Nov 2, 2024
1 parent dee0634 commit ac15b94
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
8 changes: 4 additions & 4 deletions keras_cv/src/losses/centernet_box_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ def __init__(self, num_heading_bins, anchor_size, **kwargs):
self.anchor_size = anchor_size

def heading_regression_loss(self, heading_true, heading_pred):
heading_pred = ops.convert_to_tensor(heading_pred)

# Set the heading to within 0 -> 2pi
heading_true = ops.floor(ops.mod(heading_true, 2 * math.pi))
heading_true = ops.mod(heading_true, 2 * math.pi)

# Divide 2pi into bins. shifted by 0.5 * angle_per_class.
angle_per_class = (2 * math.pi) / self.num_heading_bins
shift_angle = ops.floor(
ops.mod(heading_true + angle_per_class / 2, 2 * math.pi)
)
shift_angle = ops.mod(heading_true + angle_per_class / 2, 2 * math.pi)

heading_bin_label_float = ops.floor(
ops.divide(shift_angle, angle_per_class)
Expand Down
26 changes: 26 additions & 0 deletions keras_cv/src/losses/centernet_box_loss_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from absl.testing import parameterized

import keras_cv
from keras_cv.src.backend import ops
from keras_cv.src.tests.test_case import TestCase


Expand All @@ -42,3 +43,28 @@ def test_proper_output_shapes(self, reduction, target_size):
y_pred=np.random.uniform(size=(2, 10, 6 + 2 * 4)),
)
self.assertEqual(result.shape, target_size)

def test_heading_regression_loss(self):
num_heading_bins = 4
loss = keras_cv.losses.CenterNetBoxLoss(
num_heading_bins=num_heading_bins, anchor_size=[1.0, 1.0, 1.0]
)
heading_true = np.array(
[[0, (1 / 2.0) * np.pi, np.pi, (3.0 / 2.0) * np.pi]]
)
heading_pred = np.array(
[
[
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
]
]
)
heading_loss = loss.heading_regression_loss(
heading_true=heading_true, heading_pred=heading_pred
)
ce_loss = -np.log(np.exp(1) / np.exp([1, 0, 0, 0]).sum())
expected_loss = ce_loss * num_heading_bins
self.assertAllClose(ops.sum(heading_loss), expected_loss)

0 comments on commit ac15b94

Please sign in to comment.