
Vectorize box decoding in FCOS #6203

Merged: 10 commits into pytorch:main, Jul 5, 2022

Conversation

abhi-glitchhg
Contributor

@abhi-glitchhg abhi-glitchhg commented Jun 25, 2022

My attempt to fix #5247.

cc: @datumbox , @jdsgomes

@abhi-glitchhg abhi-glitchhg marked this pull request as draft June 25, 2022 04:53
@vadimkantorov

vadimkantorov commented Jun 30, 2022

Would it be possible to use ... instead of :, :? That way it is more flexible with respect to the number of batch dimensions.

Same question for batch encoding (which doesn't even seem to be supported right now :( )

In my own code I derived these classes from nn.Module and do the forward encoding in forward() and the decoding in inverse().
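
A tiny illustrative sketch of the ... suggestion (a hypothetical helper, not code from this PR): with ellipsis indexing the same expression works for any number of leading batch dimensions.

import torch

def box_widths(boxes: torch.Tensor) -> torch.Tensor:
    # boxes: (..., 4) in (x1, y1, x2, y2) format; `...` absorbs any leading batch dims
    return boxes[..., 2] - boxes[..., 0]

print(box_widths(torch.randn(3, 4)).shape)     # torch.Size([3])
print(box_widths(torch.randn(5, 3, 4)).shape)  # torch.Size([5, 3])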

@abhi-glitchhg
Contributor Author

abhi-glitchhg commented Jul 3, 2022

Would it be possible to use ... instead of :, :? That way it is more flexible with respect to the number of batch dimensions.

Ooo, I forgot about this method! Thanks @vadimkantorov

In my own code I derived these classes from nn.Module and do the forward encoding in forward() and the decoding in inverse().

oh! I would love to have a look at the code! 😃

@abhi-glitchhg abhi-glitchhg marked this pull request as ready for review July 3, 2022 09:59
@abhi-glitchhg abhi-glitchhg changed the title [WIP] Vectorize box decoding in FCOS Vectorize box decoding in FCOS Jul 3, 2022
Contributor

@datumbox datumbox left a comment


@abhi-glitchhg Thanks a lot for the PR. I've added a few comments as I'm uncertain about some of the changes made. Please let me know what you think. If we end up merging this, we need to confirm that the output before/after doesn't change and also benchmark this to confirm it's faster on CPU/GPU. If you don't have the hardware for the latter, I can do the benchmark using your script on internal infra.

@xiaohu2015 and @zhiqwang, would you be able to have a look and confirm that what we do here is correct?

@vfdev-5 I would also appreciate it if you could have a look and perhaps elaborate on the "equal box size" assumption you shared offline with me.

Inline review comments on torchvision/models/detection/_utils.py (5 threads, outdated and resolved)
Collaborator

@vfdev-5 vfdev-5 left a comment


I have just a few nits. Thanks for the PR, @abhi-glitchhg.

Inline review comments on torchvision/models/detection/fcos.py and torchvision/models/detection/_utils.py (outdated and resolved)
@abhi-glitchhg
Contributor Author

abhi-glitchhg commented Jul 4, 2022

Thanks, @datumbox and @vfdev-5, for your valuable comments!

I checked the correctness of my code with the following script.

import torch

from torchvision.models.detection._utils import BoxLinearCoder

box_coder = BoxLinearCoder(True)

for _ in range(10000):
    boxes = [torch.randn(3, 4) for i in range(5)]
    rel_codes = torch.randn(5, 3, 4)

    original = [
        box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
        for anchors_per_image, bbox_regression_per_image in zip(boxes, rel_codes)
    ]
    original = torch.stack(original)

    new = box_coder.decode_all(rel_codes, boxes)

    torch.testing.assert_close(original, new, rtol=0, atol=1e-6)

This did not raise any error, so I assume the implementation is correct. [But please check this from your side too. :) ]

Assuming the above code is correct, I ran the following script to check whether the new implementation is faster.

import time

import torch

from torchvision.models.detection._utils import BoxLinearCoder

box_coder = BoxLinearCoder(True)

start = time.process_time()
for _ in range(10000):
    boxes = [torch.randn(3, 4) for i in range(5)]
    rel_codes = torch.randn(5, 3, 4)

    original = [
        box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
        for anchors_per_image, bbox_regression_per_image in zip(boxes, rel_codes)
    ]
print(f"time taken for decode_single: {time.process_time() - start}")

start = time.process_time()
for _ in range(10000):
    boxes = [torch.randn(3, 4) for i in range(5)]
    rel_codes = torch.randn(5, 3, 4)
    new = box_coder.decode_all(rel_codes, boxes)
print(f"time taken for decode_all: {time.process_time() - start}")

print('done')

And the results are as follows:
[screenshot of the timing output]

I don't have a GPU, so I couldn't check on GPU 😅

@vadimkantorov

vadimkantorov commented Jul 4, 2022

Would it be possible to use ... instead of :, :? That way it is more flexible with respect to the number of batch dimensions.

Ooo, I forgot about this method! Thanks @vadimkantorov

My related feature request on doing the same fix for other box-related functions: #3478

In my own code I derived these classes from nn.Module and do the forward encoding in forward() and the decoding in inverse().

oh! I would love to have a look at the code! 😃

The forward + right_inverse method names are already used in PyTorch for the parametrization functionality; the API discussion is here: pytorch/pytorch#7313 (comment). This terminology is also often used in probability flow-based models, which are composed of invertible/reversible modules. In a sense, box coding is very much a parametrization as well, and IMO it makes sense to model it in the same terms.
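
For reference, a minimal sketch of the PyTorch parametrization API mentioned above (a standard symmetric-matrix example, not code from this PR): the parametrization module implements forward for the unconstrained-to-constrained direction and right_inverse for assignment.

import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize

class Symmetric(nn.Module):
    def forward(self, X):
        # unconstrained tensor -> constrained (symmetric) matrix
        return X.triu() + X.triu(1).transpose(-1, -2)

    def right_inverse(self, A):
        # constrained matrix -> an unconstrained representative
        return A.triu()

layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Symmetric())
layer.weight = torch.eye(3)  # assignment goes through right_inverse
print(torch.allclose(layer.weight, layer.weight.T))  # True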

My own box-coder reimplementation looks like this (written for https://github.com/shenyunhang/DRN-WSOD; it encodes boxes relative to some reference boxes, e.g. the ground-truth boxes):

import math

import torch
from torch import nn


class DrnBoxCoder(nn.Module):
    def __init__(self, weights=(10.0, 10.0, 5.0, 5.0), scale_clamp: float = math.log(1000.0 / 16)):
        super().__init__()
        self.weights = weights
        self.scale_clamp = scale_clamp


    def forward(self, x, *, reference_boxes):
        x_widths = x[..., 2] - x[..., 0]
        x_heights = x[..., 3] - x[..., 1]
        x_ctr_x = x[..., 0] + 0.5 * x_widths
        x_ctr_y = x[..., 1] + 0.5 * x_heights

        reference_widths = reference_boxes[..., 2] - reference_boxes[..., 0]
        reference_heights = reference_boxes[..., 3] - reference_boxes[..., 1]
        reference_ctr_x = reference_boxes[..., 0] + 0.5 * reference_widths
        reference_ctr_y = reference_boxes[..., 1] + 0.5 * reference_heights

        wx, wy, ww, wh = self.weights
        dx = wx * (reference_ctr_x - x_ctr_x) / x_widths
        dy = wy * (reference_ctr_y - x_ctr_y) / x_heights
        dw = ww * torch.log(reference_widths / x_widths)
        dh = wh * torch.log(reference_heights / x_heights)

        y = torch.stack((dx, dy, dw, dh), dim=-1)
        assert (x_widths >= 0).all(), "Input boxes to Box2BoxTransform are not valid!"
        return y

    def inverse(self, y, *, reference_boxes):
        y = y.float()  # ensure fp32 for decoding precision
        reference_boxes = reference_boxes.to(y.dtype)

        widths = reference_boxes[..., 2] - reference_boxes[..., 0]
        heights = reference_boxes[..., 3] - reference_boxes[..., 1]
        ctr_x = reference_boxes[..., 0] + 0.5 * widths
        ctr_y = reference_boxes[..., 1] + 0.5 * heights

        wx, wy, ww, wh = self.weights
        dx = y[..., 0::4] / wx
        dy = y[..., 1::4] / wy
        dw = y[..., 2::4] / ww
        dh = y[..., 3::4] / wh

        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=self.scale_clamp)
        dh = torch.clamp(dh, max=self.scale_clamp)

        pred_ctr_x = dx * widths[..., None] + ctr_x[..., None]
        pred_ctr_y = dy * heights[..., None] + ctr_y[..., None]
        pred_w = torch.exp(dw) * widths[..., None]
        pred_h = torch.exp(dh) * heights[..., None]

        x1 = pred_ctr_x - 0.5 * pred_w
        y1 = pred_ctr_y - 0.5 * pred_h
        x2 = pred_ctr_x + 0.5 * pred_w
        y2 = pred_ctr_y + 0.5 * pred_h
        x = torch.stack((x1, y1, x2, y2), dim=-1)
        return x.view_as(y)
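
A quick roundtrip check for the snippet above (the proposals/gt tensors are hypothetical examples; assumes the imports at the top of the snippet): decoding the encoded deltas against the same proposals recovers the reference boxes, up to the scale clamp.

coder = DrnBoxCoder()
proposals = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 20.0]])
gt = torch.tensor([[1.0, 2.0, 9.0, 12.0], [4.0, 6.0, 16.0, 18.0]])
deltas = coder(proposals, reference_boxes=gt)                  # encode
recovered = coder.inverse(deltas, reference_boxes=proposals)   # decode
print(torch.allclose(recovered, gt, atol=1e-5))  # True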

@vadimkantorov

vadimkantorov commented Jul 4, 2022

Also, it might be good to consider porting the existing box coders from detectron2 to torchvision and then using these functions in detectron2. Here is where they are defined: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/box_regression.py

@xiaohu2015
Contributor

xiaohu2015 commented Jul 5, 2022

@abhi-glitchhg good job! I think maybe we should consider merging decode_single and decode_all, because they share nearly the same code.
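
A rough sketch of that idea (not the actual torchvision implementation, and it omits the normalize_by_size scaling): a single decode written with ... indexing handles both an (N, 4) input and a batched (B, N, 4) input, so the single-image path could simply delegate to it.

import torch

def decode(rel_codes: torch.Tensor, boxes: torch.Tensor) -> torch.Tensor:
    # boxes: (..., 4) anchors in (x1, y1, x2, y2); rel_codes: (..., 4) distances (l, t, r, b)
    ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
    ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])
    return torch.stack(
        (
            ctr_x - rel_codes[..., 0],
            ctr_y - rel_codes[..., 1],
            ctr_x + rel_codes[..., 2],
            ctr_y + rel_codes[..., 3],
        ),
        dim=-1,
    )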

Inline review comments on torchvision/models/detection/_utils.py and torchvision/models/detection/fcos.py (outdated and resolved)
Contributor

@datumbox datumbox left a comment


@abhi-glitchhg Thanks a lot for your work on this.

I've run similar checks on my side. In your first script, which confirms that we produce the same results, I replaced assert_close with equal to ensure the outputs match exactly, and they do.

For the benchmarks, I modified your script slightly to move the random input generation out of the loops and to add back the stack call that was part of the previous solution. Here is the updated script:

import time
import torch
from torchvision.models.detection._utils import BoxLinearCoder

device = "cpu"
boxes = [torch.randn(3, 4).to(device=device) for i in range(5)]
rel_codes = torch.randn(5, 3, 4).to(device=device)

box_coder = BoxLinearCoder(True)

# Warmup

for _ in range(1000):
    original = [
        box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
        for anchors_per_image, bbox_regression_per_image in zip(boxes, rel_codes)
    ]
    original = torch.stack(original)
    new = box_coder.decode_all(rel_codes, boxes)

start = time.process_time()
for _ in range(10000):
    original = [
        box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
        for anchors_per_image, bbox_regression_per_image in zip(boxes, rel_codes)
    ]
    original = torch.stack(original)
print(f"time taken for decode_single: {time.process_time() - start}")

start = time.process_time()
for _ in range(10000):
    new = box_coder.decode_all(rel_codes, boxes)
print(f"time taken for decode_all: {time.process_time() - start}")

print('done ')
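
For GPU numbers, a timing loop like the one above would also need explicit synchronization before each clock read, since CUDA kernels launch asynchronously. A minimal sketch of that extra step (assuming a CUDA device is available and reusing box_coder, boxes and rel_codes from the script above):

if torch.cuda.is_available():
    boxes_gpu = [b.to("cuda") for b in boxes]
    rel_codes_gpu = rel_codes.to("cuda")

    torch.cuda.synchronize()  # make sure pending work is done before timing
    start = time.perf_counter()
    for _ in range(10000):
        new = box_coder.decode_all(rel_codes_gpu, boxes_gpu)
    torch.cuda.synchronize()  # wait for the queued kernels to finish
    print(f"GPU time for decode_all: {time.perf_counter() - start}")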

On CPU I get a 5.5x improvement and on GPU about 5x. So LGTM!

Thanks a lot to everyone for their reviews and suggestions. And +1 on cleaning up the single implementation, as it's quite similar to this one. @abhi-glitchhg perhaps you could tackle this in a follow-up PR?

Otherwise I think we are good to merge as this optimization seems quite good. :)

@datumbox datumbox merged commit b3b7448 into pytorch:main Jul 5, 2022
@github-actions

github-actions bot commented Jul 5, 2022

Hey @datumbox!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

facebook-github-bot pushed a commit that referenced this pull request Jul 6, 2022
Summary:
* basic structure

* added constraints

* fixed errors

* thanks to vadim!

* addressing the comments and added docstring

* Apply suggestions from code review

Reviewed By: jdsgomes

Differential Revision: D37643906

fbshipit-source-id: 13b15e886d23e259a5ed6926677cac55f160f206

Co-authored-by: Vasilis Vryniotis <[email protected]>
@abhi-glitchhg abhi-glitchhg deleted the vectorize branch August 8, 2022 18:16