Skip to content

Commit

Permalink
Fix UpsampleNearest op CPU impl batch handling (pytorch#13002)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#13002

Batch dim wasn't handled in the CPU impl (will fail for inputs with N > 1).
Fixing that here.

Differential Revision: D10515159

fbshipit-source-id: ee7e4f489d2d4de793f550b31db7c0e2ba3651e8
  • Loading branch information
viswanathgs authored and facebook-github-bot committed Oct 24, 2018
1 parent 353fdef commit 1bea5fc
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
2 changes: 1 addition & 1 deletion modules/detectron/upsample_nearest_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class UpsampleNearestOp final : public Operator<Context> {
d2 = Y->dim32(1);
d3 = Y->dim32(2);
} else {
d1 = Y->dim32(1);
d1 = Y->dim32(0) * Y->dim32(1);
d2 = Y->dim32(2);
d3 = Y->dim32(3);
}
Expand Down
43 changes: 43 additions & 0 deletions modules/detectron/upsample_nearest_op_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import unittest

import caffe2.python.hypothesis_test_util as hu
import hypothesis.strategies as st
import numpy as np
from caffe2.python import core, dyndep
from hypothesis import given


dyndep.InitOpsLibrary("@/caffe2/modules/detectron:detectron_ops")


class TestUpsampleNearestOp(hu.HypothesisTestCase):
@given(
N=st.integers(1, 3),
H=st.integers(10, 300),
W=st.integers(10, 300),
scale=st.integers(1, 3),
**hu.gcs
)
def test_upsample_nearest_op(self, N, H, W, scale, gc, dc):
C = 32
X = np.random.randn(N, C, H, W).astype(np.float32)
op = core.CreateOperator("UpsampleNearest", ["X"], ["Y"], scale=scale)

def ref(X):
outH = H * scale
outW = W * scale
outH_idxs, outW_idxs = np.meshgrid(
np.arange(outH), np.arange(outW), indexing="ij"
)
inH_idxs = (outH_idxs / scale).astype(np.int32)
inW_idxs = (outW_idxs / scale).astype(np.int32)
Y = X[:, :, inH_idxs, inW_idxs]
return [Y]

self.assertReferenceChecks(device_option=gc, op=op, inputs=[X], reference=ref)


if __name__ == "__main__":
unittest.main()

0 comments on commit 1bea5fc

Please sign in to comment.