Skip to content

Commit

Permalink
Fix TestAnchorGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Dec 20, 2021
1 parent 3cf8abf commit ed6a448
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions test/test_models_anchor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@ def test_anchor_generator(self):
model.eval()
anchors = model(features)

expected_anchor_output = torch.tensor([[-0.5, -0.5], [0.5, -0.5], [-0.5, 0.5], [0.5, 0.5]])
expected_wh_output = torch.tensor([[4.0], [4.0], [4.0], [4.0]])
expected_xy_output = torch.tensor([[6.0, 14.0], [6.0, 14.0], [6.0, 14.0], [6.0, 14.0]])
expected_grids = torch.tensor([[[[[0.0, 0.0], [1.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]]]])
expected_shifts = torch.tensor([[[[[6.0, 14.0], [6.0, 14.0]], [[6.0, 14.0], [6.0, 14.0]]]]])

assert len(anchors) == 3
assert tuple(anchors[0].shape) == (4, 2)
assert tuple(anchors[1].shape) == (4, 1)
assert tuple(anchors[2].shape) == (4, 2)
assert len(anchors) == 2
assert len(anchors[0]) == len(anchors[1]) == 1
assert tuple(anchors[0][0].shape) == (1, 1, 2, 2, 2)
assert tuple(anchors[1][0].shape) == (1, 1, 2, 2, 2)

torch.testing.assert_close(anchors[0], expected_anchor_output, rtol=0, atol=0)
torch.testing.assert_close(anchors[1], expected_wh_output, rtol=0, atol=0)
torch.testing.assert_close(anchors[2], expected_xy_output, rtol=0, atol=0)
torch.testing.assert_close(anchors[0][0], expected_grids, rtol=0, atol=0)
torch.testing.assert_close(anchors[1][0], expected_shifts, rtol=0, atol=0)

0 comments on commit ed6a448

Please sign in to comment.