Skip to content

Commit

Permalink
Fixing detection saliency map for one class case (#2368)
Browse files Browse the repository at this point in the history
* fix softmax

* fix validity tests
  • Loading branch information
negvet authored Jul 17, 2023
1 parent 895bd36 commit 3d157ab
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def func(

# Don't use softmax for tiles in tiling detection, if the tile doesn't contain objects,
# it would highlight one of the class maps as a background class
if self.use_cls_softmax:
if self.use_cls_softmax and self._num_cls_out_channels > 1:
cls_scores = [torch.softmax(t, dim=1) for t in cls_scores]

batch_size, _, height, width = cls_scores[-1].size()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,6 @@ def test_saliency_map_cls(self, template):
assert len(saliency_maps) == 2
assert saliency_maps[0].ndim == 3
assert saliency_maps[0].shape == (1000, 7, 7)
assert np.all(np.abs(saliency_maps[0][0][0] - self.ref_saliency_vals_cls[template.name]) <= 1)
actual_sal_vals = saliency_maps[0][0][0].astype(np.int8)
ref_sal_vals = self.ref_saliency_vals_cls[template.name].astype(np.int8)
assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1)
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def test_saliency_map_det(self, template):
assert len(saliency_maps) == 2
assert saliency_maps[0].ndim == 3
assert saliency_maps[0].shape == self.ref_saliency_shapes[template.name]
assert np.all(np.abs(saliency_maps[0][0][0] - self.ref_saliency_vals_det[template.name]) <= 1)
actual_sal_vals = saliency_maps[0][0][0].astype(np.int8)
ref_sal_vals = self.ref_saliency_vals_det[template.name].astype(np.int8)
assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1)

@e2e_pytest_unit
@pytest.mark.parametrize("template", templates_det, ids=templates_det_ids)
Expand Down

0 comments on commit 3d157ab

Please sign in to comment.