From 8179a13ddd9be6e9437e91bc99c1e66ea4782bbf Mon Sep 17 00:00:00 2001 From: Brian Ko Date: Mon, 21 Dec 2020 10:49:52 +0900 Subject: [PATCH 1/7] Implement IoU metric and IoU loss --- pl_bolts/losses/object_detection.py | 28 ++++++++++++++++++++++-- pl_bolts/metrics/object_detection.py | 32 ++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/pl_bolts/losses/object_detection.py b/pl_bolts/losses/object_detection.py index 81d0404813..ccbba5e707 100644 --- a/pl_bolts/losses/object_detection.py +++ b/pl_bolts/losses/object_detection.py @@ -1,10 +1,34 @@ """ -Generalized Intersection over Union (GIoU) loss (Rezatofighi et. al) +Loss functions for Object Detection task """ import torch -from pl_bolts.metrics.object_detection import giou +from pl_bolts.metrics.object_detection import giou, iou + + +def iou_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Calculates the intersection over union loss. + + Args: + preds: batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` + target: batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` + + Example: + + >>> import torch + >>> from pl_bolts.losses.object_detection import iou_loss + >>> preds = torch.tensor([[100, 100, 200, 200]]) + >>> target = torch.tensor([[150, 150, 250, 250]]) + >>> iou_loss(preds, target) + tensor([[0.8571]]) + + Returns: + IoU loss + """ + loss = 1 - iou(preds, target) + return loss def giou_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: diff --git a/pl_bolts/metrics/object_detection.py b/pl_bolts/metrics/object_detection.py index 3175f3ce24..e2da6a5eaa 100644 --- a/pl_bolts/metrics/object_detection.py +++ b/pl_bolts/metrics/object_detection.py @@ -1,6 +1,38 @@ import torch +def iou(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Calculates the intersection over union. + + Args: + preds: batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` + target: batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` + + Example: + + >>> import torch + >>> from pl_bolts.metrics.object_detection import iou + >>> preds = torch.tensor([[100, 100, 200, 200]]) + >>> target = torch.tensor([[150, 150, 250, 250]]) + >>> iou(preds, target) + tensor([[0.1429]]) + + Returns: + IoU value + """ + x_min = torch.max(preds[:, None, 0], target[:, 0]) + y_min = torch.max(preds[:, None, 1], target[:, 1]) + x_max = torch.min(preds[:, None, 2], target[:, 2]) + y_max = torch.min(preds[:, None, 3], target[:, 3]) + intersection = (x_max - x_min).clamp(min=0) * (y_max - y_min).clamp(min=0) + pred_area = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1]) + target_area = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1]) + union = pred_area[:, None] + target_area - intersection + iou = torch.true_divide(intersection, union) + return iou + + def giou(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Calculates the generalized intersection over union. From 16516e55a66e153148804ce9dbfe6e68d083ad5d Mon Sep 17 00:00:00 2001 From: Brian Ko Date: Mon, 21 Dec 2020 10:56:35 +0900 Subject: [PATCH 2/7] Add tests for IoU metric & IoU loss --- tests/losses/test_object_detection.py | 17 ++++++++++++++++- tests/metrics/test_object_detection.py | 17 ++++++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/tests/losses/test_object_detection.py b/tests/losses/test_object_detection.py index 30f0ab4576..a6117f7eb8 100644 --- a/tests/losses/test_object_detection.py +++ b/tests/losses/test_object_detection.py @@ -5,7 +5,22 @@ import pytest import torch -from pl_bolts.losses.object_detection import giou_loss +from pl_bolts.losses.object_detection import giou_loss, iou_loss + + +@pytest.mark.parametrize("preds, target, expected_loss", [ + (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([0.0])) +]) +def test_iou_complete_overlap(preds, target, expected_loss): + torch.testing.assert_allclose(iou_loss(preds, target), expected_loss) + + +@pytest.mark.parametrize("preds, target, expected_loss", [ + (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 200, 200, 300]]), torch.tensor([1.0])), + (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[200, 200, 300, 300]]), torch.tensor([1.0])), +]) +def test_iou_no_overlap(preds, target, expected_loss): + torch.testing.assert_allclose(iou_loss(preds, target), expected_loss) @pytest.mark.parametrize("preds, target, expected_loss", [ diff --git a/tests/metrics/test_object_detection.py b/tests/metrics/test_object_detection.py index a998502314..efe0b7234d 100644 --- a/tests/metrics/test_object_detection.py +++ b/tests/metrics/test_object_detection.py @@ -5,7 +5,22 @@ import pytest import torch -from pl_bolts.metrics.object_detection import giou +from pl_bolts.metrics.object_detection import giou, iou + + +@pytest.mark.parametrize("preds, target, expected_iou", [ + (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([1.0])) +]) +def test_iou_complete_overlap(preds, target, expected_iou): + torch.testing.assert_allclose(iou(preds, target), expected_iou) + + +@pytest.mark.parametrize("preds, target, expected_iou", [ + (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 200, 200, 300]]), torch.tensor([0.0])), + (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[200, 200, 300, 300]]), torch.tensor([0.0])), +]) +def test_iou_no_overlap(preds, target, expected_iou): + torch.testing.assert_allclose(iou(preds, target), expected_iou) @pytest.mark.parametrize("preds, target, expected_giou", [ From 4049f1561a87a26e7b7f5768da1ba8e181fe8337 Mon Sep 17 00:00:00 2001 From: Brian Ko Date: Mon, 21 Dec 2020 11:14:17 +0900 Subject: [PATCH 3/7] Add documentation for IoU loss --- docs/source/losses.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 901bb71964..7a0f09aee0 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -27,6 +27,14 @@ GIoU Loss --------------- +IoU Loss +-------- + +.. autofunction:: pl_bolts.losses.object_detection.iou_loss + :noindex: + +--------------- + Reinforcement Learning ====================== These are common losses used in RL. From ce8ebb2f2c3e7ed9a8ef195c0239f04138c1e824 Mon Sep 17 00:00:00 2001 From: Brian Ko Date: Mon, 21 Dec 2020 11:15:11 +0900 Subject: [PATCH 4/7] Update CHANGELOG --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4256c8623b..08496b9146 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added GIoU loss ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347)) +- Added IoU loss ([#469](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/469)) + ### Changed - Decoupled datamodules from models ([#332](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/332), [#270](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/270)) From d80a94ec1b89329e1eb3b68435470bd409d03c87 Mon Sep 17 00:00:00 2001 From: Brian Ko Date: Mon, 21 Dec 2020 20:33:57 +0900 Subject: [PATCH 5/7] Update IoU docstring --- pl_bolts/metrics/object_detection.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pl_bolts/metrics/object_detection.py b/pl_bolts/metrics/object_detection.py index e2da6a5eaa..206c8783f8 100644 --- a/pl_bolts/metrics/object_detection.py +++ b/pl_bolts/metrics/object_detection.py @@ -19,7 +19,8 @@ def iou(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: tensor([[0.1429]]) Returns: - IoU value + IoU tensor: an NxM tensor containing the pairwise IoU values for every element in preds and target, + where N is the number of prediction bounding boxes and M is the number of target bounding boxes """ x_min = torch.max(preds[:, None, 0], target[:, 0]) y_min = torch.max(preds[:, None, 1], target[:, 1]) From e64b09ef004612ecf9937860489a46752435e27c Mon Sep 17 00:00:00 2001 From: Brian Ko Date: Wed, 23 Dec 2020 07:43:15 +0900 Subject: [PATCH 6/7] Add tensor shape for IoU docstring --- pl_bolts/metrics/object_detection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_bolts/metrics/object_detection.py b/pl_bolts/metrics/object_detection.py index 206c8783f8..21352888b8 100644 --- a/pl_bolts/metrics/object_detection.py +++ b/pl_bolts/metrics/object_detection.py @@ -6,8 +6,8 @@ def iou(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Calculates the intersection over union. Args: - preds: batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` - target: batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` + preds: an Nx4 batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` + target: an Mx4 batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` Example: From 17a3262a4f05428c80f49995c4e29d9e8ac5b17d Mon Sep 17 00:00:00 2001 From: Brian Ko Date: Wed, 23 Dec 2020 09:15:44 +0900 Subject: [PATCH 7/7] Add tests for IoU/GIoU from torchvision --- tests/metrics/test_object_detection.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/metrics/test_object_detection.py b/tests/metrics/test_object_detection.py index efe0b7234d..59b2d8f32e 100644 --- a/tests/metrics/test_object_detection.py +++ b/tests/metrics/test_object_detection.py @@ -23,6 +23,17 @@ def test_iou_no_overlap(preds, target, expected_iou): torch.testing.assert_allclose(iou(preds, target), expected_iou) +@pytest.mark.parametrize("preds, target, expected_iou", [ + ( + torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]), + torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]), + torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]) + ) +]) +def test_iou_multi(preds, target, expected_iou): + torch.testing.assert_allclose(iou(preds, target), expected_iou) + + @pytest.mark.parametrize("preds, target, expected_giou", [ (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([1.0])) ]) @@ -36,3 +47,14 @@ def test_complete_overlap(preds, target, expected_giou): ]) def test_no_overlap(preds, target, expected_giou): torch.testing.assert_allclose(giou(preds, target), expected_giou) + + +@pytest.mark.parametrize("preds, target, expected_giou", [ + ( + torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]), + torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]), + torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]]) + ) +]) +def test_giou_multi(preds, target, expected_giou): + torch.testing.assert_allclose(giou(preds, target), expected_giou)