
Commit dfe8ec1
Fix bug pointed out by @lopuhin
fmassa committed May 7, 2019
1 parent bf5c1d1 commit dfe8ec1
Showing 2 changed files with 40 additions and 9 deletions.
39 changes: 35 additions & 4 deletions test/test_ops.py
@@ -15,7 +15,8 @@ def setUpClass(cls):
 
     def slow_roi_pooling(self, x, rois, pool_h, pool_w, spatial_scale=1,
                          device=torch.device('cpu'), dtype=torch.float64):
-        y = torch.zeros(rois.size(0), x.size(1), pool_h, pool_w, dtype=dtype, device=device)
+        c = x.size(1)
+        y = torch.zeros(rois.size(0), c, pool_h, pool_w, dtype=dtype, device=device)
 
         rois = torch.round(rois * spatial_scale)
 
@@ -24,14 +25,16 @@ def slow_roi_pooling(self, x, rois, pool_h, pool_w, spatial_scale=1,
                 if roi[0] == n:
                     start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1
                     start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1
-                    roi_x = x[roi[0].long():roi[0].long() + 1, :, start_h:end_h, start_w:end_w]
-                    bin_h, bin_w = roi_x.size(2) / pool_h, roi_x.size(3) / pool_w
+                    roi_x = x[roi[0].long(), :, start_h:end_h, start_w:end_w]
+                    bin_h, bin_w = roi_x.size(-2) / pool_h, roi_x.size(-1) / pool_w
 
                     for j in range(0, pool_h):
                         cj = slice(int(np.floor(j * bin_h)), int(np.ceil((j + 1) * bin_h)))
                         for i in range(0, pool_w):
                             ci = slice(int(np.floor(i * bin_w)), int(np.ceil((i + 1) * bin_w)))
-                            y[r, :, j, i] = torch.max(y[r, :, j, i], torch.max(roi_x[:, :, cj, ci]))
+                            t = roi_x[:, cj, ci].reshape(c, -1)
+                            if t.numel() > 0:
+                                y[r, :, j, i] = torch.max(t, 1)[0]
         return y
 
     def test_roi_pool_basic_cpu(self):
@@ -75,6 +78,34 @@ def test_roi_pool_cpu(self):
         gt_y = self.slow_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device=device, dtype=self.dtype)
         assert torch.allclose(gt_y, y), 'RoIPool layer incorrect on CPU for batch > 1'
 
+    def test_roi_pool_cpu_empty_rois(self):
+        device = torch.device('cpu')
+        x = torch.tensor(
+            [[[[0.1767, 1.2851, 4.2325, 4.8645, 7.1496]],
+              [[2.5916, 4.3361, 3.8143, 6.1329, 2.0230]],
+              [[1.4492, 3.3384, 4.0816, 6.3116, 5.1068]]]],
+            dtype=self.dtype, device=device)
+        rois = torch.tensor(
+            [[0., 1., 0., 4., 0.],
+             [0., 2., 0., 3., 0.],
+             [0., 0., 0., 0., 0.],
+             [0., 0., 0., 0., 0.],
+             [0., 2., 0., 2., 0.]],
+            dtype=self.dtype, device=device)
+
+        pool_h, pool_w = (1, 2)
+        roi_pool = ops.RoIPool((pool_h, pool_w), 1)
+        y = roi_pool(x, rois)
+
+        gt_y = self.slow_roi_pooling(x, rois, pool_h, pool_w, device=device, dtype=self.dtype)
+
+        assert torch.allclose(gt_y, y), 'RoIPool layer incorrect on CPU empty rois'
+
+        # non-contiguous
+        y = roi_pool(x.permute(0, 1, 3, 2), rois)
+        gt_y = self.slow_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device=device, dtype=self.dtype)
+        assert torch.allclose(gt_y, y), 'RoIPool layer incorrect on CPU for empty rois non-contiguous'
+
     def test_roi_pool_gradient_cpu(self):
         device = torch.device('cpu')
         x = torch.ones(1, 1, 10, 10, dtype=self.dtype, device=device, requires_grad=True)
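Note on the change to the slow_roi_pooling reference above: each output bin now takes an independent maximum per channel (reshape to (c, -1), then torch.max along dim 1), whereas the old line collapsed every channel into a single scalar via torch.max(roi_x[:, :, cj, ci]). The new t.numel() > 0 guard also avoids calling torch.max on an empty slice, so bins produced by degenerate RoIs stay at zero. A minimal sketch of the per-channel difference, not part of the commit and using made-up values:

import torch

# one pooling bin with 2 channels and 2 spatial positions
roi_bin = torch.tensor([[1., 7.],    # channel 0
                        [5., 2.]])   # channel 1

old_style = torch.max(roi_bin)                        # tensor(7.): one scalar shared by all channels
new_style = torch.max(roi_bin.reshape(2, -1), 1)[0]   # tensor([7., 5.]): an independent max per channel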
10 changes: 5 additions & 5 deletions torchvision/csrc/cpu/ROIPool_cpu.cpp
@@ -49,12 +49,12 @@ void RoIPoolForward(
         wend = std::min(std::max(wend + roi_start_w, 0), width);
         bool is_empty = (hend <= hstart) || (wend <= wstart);
 
-        // Define an empty pooling region to be zero
-        T maxval = is_empty ? 0 : -FLT_MAX;
-        // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
-        int maxidx = -1;
-
         for (int c = 0; c < channels; ++c) {
+          // Define an empty pooling region to be zero
+          T maxval = is_empty ? 0 : -FLT_MAX;
+          // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
+          int maxidx = -1;
+
           const T* input_offset =
               input + (roi_batch_ind * channels + c) * height * width;
 
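In the C++ kernel, the fix moves maxval and maxidx inside the channel loop, so each channel starts from a fresh maximum (and its own argmax) instead of carrying over the running value from the previous channel. A rough Python sketch of the corrected per-bin, per-channel logic, not part of the commit (pool_one_bin and input_c are illustrative names; input_c stands for one channel of the input feature map):

import torch

def pool_one_bin(input_c, hstart, hend, wstart, wend):
    # mirrors the C++ above: an empty bin is defined to be zero, and
    # argmax = -1 means nothing receives gradient in the backward pass
    is_empty = (hend <= hstart) or (wend <= wstart)
    maxval = 0.0 if is_empty else float('-inf')
    maxidx = -1
    width = input_c.size(1)
    for h in range(hstart, hend):
        for w in range(wstart, wend):
            v = input_c[h, w].item()
            if v > maxval:
                maxval = v
                maxidx = h * width + w
    return maxval, maxidx

x_c = torch.rand(6, 6)                   # one channel of a feature map
print(pool_one_bin(x_c, 0, 3, 0, 3))     # max over a 3x3 window and its flat index
print(pool_one_bin(x_c, 2, 2, 0, 3))     # empty bin -> (0.0, -1)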
