Skip to content

Commit

Permalink
Fix RoIPool reference implementation in Python 2
Browse files Browse the repository at this point in the history
Also fixes a bug in the clip_boxes_to_image -- this function needs a test!
  • Loading branch information
fmassa committed May 7, 2019
1 parent dfe8ec1 commit 4bdc54c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def slow_roi_pooling(self, x, rois, pool_h, pool_w, spatial_scale=1,
start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1
start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1
roi_x = x[roi[0].long(), :, start_h:end_h, start_w:end_w]
bin_h, bin_w = roi_x.size(-2) / pool_h, roi_x.size(-1) / pool_w
bin_h, bin_w = roi_x.size(-2) / float(pool_h), roi_x.size(-1) / float(pool_w)

for j in range(0, pool_h):
cj = slice(int(np.floor(j * bin_h)), int(np.ceil((j + 1) * bin_h)))
Expand Down
11 changes: 6 additions & 5 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,14 @@ def clip_boxes_to_image(boxes, size):
Returns:
clipped_boxes (Tensor[N, 4])
"""
boxes_x = boxes[:, 0::2]
boxes_y = boxes[:, 1::2]
dim = boxes.dim()
boxes_x = boxes[..., 0::2]
boxes_y = boxes[..., 1::2]
height, width = size
boxes_x = boxes_x.clamp(min=0, max=width)
boxes_y = boxes_x.clamp(min=0, max=height)
clipped_boxes = torch.stack((boxes_x, boxes_y), dim=2)
return clipped_boxes.reshape(-1, 4)
boxes_y = boxes_y.clamp(min=0, max=height)
clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
return clipped_boxes.reshape(boxes.shape)


def box_area(boxes):
Expand Down

0 comments on commit 4bdc54c

Please sign in to comment.