Vectorize box decoding in FCOS #6203
Conversation
Would it be possible to use … ? Same question for batch encoding (which seems not even supported now :( ). In my own code I made these classes derive from nn.Module and then do forward encoding in …
Ooo, I forgot about this method! Thanks @vadimkantorov
oh! I would love to have a look at the code! 😃
@abhi-glitchhg Thanks a lot for the PR. I've added a few comments as I'm uncertain about some of the changes made. Please let me know what you think. If we end up merging this, we need to confirm that the output before/after doesn't change and also benchmark this to confirm it's faster on CPU/GPU. If you don't have the hardware for the latter, I can do the benchmark using your script on internal infra.
@xiaohu2015 and @zhiqwang would you please be able to have a look and confirm that what we do here is correct?
@vfdev-5 I would also appreciate it if you could have a look and perhaps elaborate on the "equal box size" assumption you shared with me offline.
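For context, here is my guess at what the "equal box size" assumption refers to: if the vectorized path stacks the per-image anchor tensors into one batch (an assumption on my part about how `decode_all` batches its inputs), then every image must contribute the same number of boxes. A minimal illustration, not torchvision code:

```python
import torch

# Batching per-image box tensors with torch.stack requires equal counts.
same = [torch.randn(3, 4) for _ in range(5)]
stacked = torch.stack(same)  # works: shape (5, 3, 4)
assert stacked.shape == (5, 3, 4)

# With unequal per-image counts, stacking fails, so the vectorized
# decode path implicitly assumes a fixed number of boxes per image.
uneven = [torch.randn(3, 4), torch.randn(2, 4)]
try:
    torch.stack(uneven)
    stack_failed = False
except RuntimeError:
    stack_failed = True
assert stack_failed
```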
I have just a few nits. Thanks for the PR @abhi-glitchhg
Thanks, @datumbox and @vfdev-5, for your valuable comments! I checked the correctness of my code with the following script:

```python
import torch

from torchvision.models.detection._utils import BoxLinearCoder

box_coder = BoxLinearCoder(True)
for _ in range(10000):
    boxes = [torch.randn(3, 4) for i in range(5)]
    rel_codes = torch.randn(5, 3, 4)
    original = [
        box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
        for anchors_per_image, bbox_regression_per_image in zip(boxes, rel_codes)
    ]
    original = torch.stack(original)
    new = box_coder.decode_all(rel_codes, boxes)
    torch.testing.assert_close(original, new, rtol=0, atol=1e-6)
```

This did not raise any error, so I assume the implementation is correct. [But please check this from your side too. :) ]

Assuming the above code is correct, I ran the following script to check whether the new implementation is faster:

```python
import time

import torch

from torchvision.models.detection._utils import BoxLinearCoder

box_coder = BoxLinearCoder(True)

start = time.process_time()
for _ in range(10000):
    boxes = [torch.randn(3, 4) for i in range(5)]
    rel_codes = torch.randn(5, 3, 4)
    original = [
        box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
        for anchors_per_image, bbox_regression_per_image in zip(boxes, rel_codes)
    ]
print(f"time taken for decode_single: {time.process_time() - start}")

start = time.process_time()
for _ in range(10000):
    boxes = [torch.randn(3, 4) for i in range(5)]
    rel_codes = torch.randn(5, 3, 4)
    new = box_coder.decode_all(rel_codes, boxes)
print(f"time taken for decode_all: {time.process_time() - start}")
print("done")
```

I do not have a GPU, so I couldn't check on GPU 😅
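For readers following along, the difference between the looped and vectorized paths can be sketched with a simplified, self-contained linear coder. The decode arithmetic below is an illustrative stand-in (centre offsets to box edges), not torchvision's exact `BoxLinearCoder`:

```python
import torch


def decode_single(rel_codes, boxes):
    # Loop-style decode for one image: offsets are (left, top, right, bottom)
    # distances from each anchor's centre.
    ctr_x = 0.5 * (boxes[:, 0] + boxes[:, 2])
    ctr_y = 0.5 * (boxes[:, 1] + boxes[:, 3])
    return torch.stack(
        (ctr_x - rel_codes[:, 0], ctr_y - rel_codes[:, 1],
         ctr_x + rel_codes[:, 2], ctr_y + rel_codes[:, 3]), dim=1)


def decode_all(rel_codes, boxes):
    # Same arithmetic over a stacked (images, anchors, 4) batch in one shot.
    boxes = torch.stack(boxes)
    ctr_x = 0.5 * (boxes[..., 0] + boxes[..., 2])
    ctr_y = 0.5 * (boxes[..., 1] + boxes[..., 3])
    return torch.stack(
        (ctr_x - rel_codes[..., 0], ctr_y - rel_codes[..., 1],
         ctr_x + rel_codes[..., 2], ctr_y + rel_codes[..., 3]), dim=-1)


boxes = [torch.randn(3, 4) for _ in range(5)]
rel_codes = torch.randn(5, 3, 4)
looped = torch.stack([decode_single(c, b) for b, c in zip(boxes, rel_codes)])
batched = decode_all(rel_codes, boxes)
# Elementwise arithmetic is identical, so the two paths agree exactly.
torch.testing.assert_close(looped, batched, rtol=0, atol=0)
```

The vectorized version does the same floating-point operations per element, just over one big tensor instead of a Python loop, which is where the speedup comes from.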
My related feature request on doing the same fix for other box-related functions: #3478
The `forward` + `right_inverse` method names are already used in PyTorch for the parametrization functionality; the API discussion is here: pytorch/pytorch#7313 (comment). This terminology is also often used in probability flow-based models, which are composed of invertible/reversible modules. In a sense, box coding is very much a parametrization as well, and IMO it makes sense to model it in the same terms. It looks like this (my reimplementation of the box coder for https://github.com/shenyunhang/DRN-WSOD; it encodes boxes in relation to some reference, the ground-truth boxes):

```python
import math

import torch
from torch import nn


class DrnBoxCoder(nn.Module):
    def __init__(self, weights=(10.0, 10.0, 5.0, 5.0), scale_clamp: float = math.log(1000.0 / 16)):
        super().__init__()
        self.weights = weights
        self.scale_clamp = scale_clamp

    def forward(self, x, *, reference_boxes):
        x_widths = x[..., 2] - x[..., 0]
        x_heights = x[..., 3] - x[..., 1]
        x_ctr_x = x[..., 0] + 0.5 * x_widths
        x_ctr_y = x[..., 1] + 0.5 * x_heights
        reference_widths = reference_boxes[..., 2] - reference_boxes[..., 0]
        reference_heights = reference_boxes[..., 3] - reference_boxes[..., 1]
        reference_ctr_x = reference_boxes[..., 0] + 0.5 * reference_widths
        reference_ctr_y = reference_boxes[..., 1] + 0.5 * reference_heights
        wx, wy, ww, wh = self.weights
        dx = wx * (reference_ctr_x - x_ctr_x) / x_widths
        dy = wy * (reference_ctr_y - x_ctr_y) / x_heights
        dw = ww * torch.log(reference_widths / x_widths)
        dh = wh * torch.log(reference_heights / x_heights)
        y = torch.stack((dx, dy, dw, dh), dim=-1)
        assert (x_widths >= 0).all(), "Input boxes to Box2BoxTransform are not valid!"
        return y

    def inverse(self, y, *, reference_boxes):
        y = y.float()  # ensure fp32 for decoding precision
        reference_boxes = reference_boxes.to(y.dtype)
        widths = reference_boxes[..., 2] - reference_boxes[..., 0]
        heights = reference_boxes[..., 3] - reference_boxes[..., 1]
        ctr_x = reference_boxes[..., 0] + 0.5 * widths
        ctr_y = reference_boxes[..., 1] + 0.5 * heights
        wx, wy, ww, wh = self.weights
        dx = y[..., 0::4] / wx
        dy = y[..., 1::4] / wy
        dw = y[..., 2::4] / ww
        dh = y[..., 3::4] / wh
        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=self.scale_clamp)
        dh = torch.clamp(dh, max=self.scale_clamp)
        pred_ctr_x = dx * widths[..., None] + ctr_x[..., None]
        pred_ctr_y = dy * heights[..., None] + ctr_y[..., None]
        pred_w = torch.exp(dw) * widths[..., None]
        pred_h = torch.exp(dh) * heights[..., None]
        x1 = pred_ctr_x - 0.5 * pred_w
        y1 = pred_ctr_y - 0.5 * pred_h
        x2 = pred_ctr_x + 0.5 * pred_w
        y2 = pred_ctr_y + 0.5 * pred_h
        x = torch.stack((x1, y1, x2, y2), dim=-1)
        return x.view_as(y)
```
Also, it might be good to consider porting the existing box coders from detectron2 to torchvision and using these functions in detectron2. Here is where they are defined: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/box_regression.py
@abhi-glitchhg good job! I think maybe we should consider merging the …
@abhi-glitchhg Thanks a lot for your work on this.
I've run similar checks on my side. On your first script, which confirms that we produce the same results, I replaced the `assert_close` with `equal` to ensure the output matches exactly, and it does.

For the benchmarks, I slightly modified your script to move the random input generation out of the loops and revert the `stack` that was part of the previous solution. Here is the updated script:
```python
import time

import torch

from torchvision.models.detection._utils import BoxLinearCoder

device = "cpu"
boxes = [torch.randn(3, 4).to(device=device) for i in range(5)]
rel_codes = torch.randn(5, 3, 4).to(device=device)
box_coder = BoxLinearCoder(True)

# Warmup
for _ in range(1000):
    original = [
        box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
        for anchors_per_image, bbox_regression_per_image in zip(boxes, rel_codes)
    ]
    original = torch.stack(original)
    new = box_coder.decode_all(rel_codes, boxes)

start = time.process_time()
for _ in range(10000):
    original = [
        box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
        for anchors_per_image, bbox_regression_per_image in zip(boxes, rel_codes)
    ]
    original = torch.stack(original)
print(f"time taken for decode_single: {time.process_time() - start}")

start = time.process_time()
for _ in range(10000):
    new = box_coder.decode_all(rel_codes, boxes)
print(f"time taken for decode_all: {time.process_time() - start}")
print("done")
```
On CPU I get a 5.5x improvement and on GPU about 5x. So LGTM!
Thanks a lot to everyone for their reviews and suggestions. And +1 on cleaning up the single implementation, as it's quite similar to this one. @abhi-glitchhg perhaps you could tackle this in a follow-up PR?
Otherwise I think we are good to merge as this optimization seems quite good. :)
Hey @datumbox! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py
Summary:

* basic structure
* added constraints
* fixed errors
* thanks to vadim!
* addressing the comments and added docstring
* Apply suggestions from code review

Reviewed By: jdsgomes

Differential Revision: D37643906

fbshipit-source-id: 13b15e886d23e259a5ed6926677cac55f160f206

Co-authored-by: Vasilis Vryniotis <[email protected]>
My attempt to fix #5247.

cc: @datumbox, @jdsgomes