diff --git a/openmc/bounding_box.py b/openmc/bounding_box.py index e5655cf8cae..6e58ca8ba02 100644 --- a/openmc/bounding_box.py +++ b/openmc/bounding_box.py @@ -95,9 +95,23 @@ def __or__(self, other: BoundingBox) -> BoundingBox: new |= other return new - def __contains__(self, point): - """Check whether or not a point is in the bounding box""" - return all(point > self.lower_left) and all(point < self.upper_right) + def __contains__(self, other): + """Check whether or not a point or another bounding box is in the bounding box. + + For another bounding box to be in the parent it must lie fully inside of it. + """ + # test for a single point + if isinstance(other, (tuple, list, np.ndarray)): + point = other + check_length("Point", point, 3, 3) + return all(point > self.lower_left) and all(point < self.upper_right) + elif isinstance(other, BoundingBox): + return all([p in self for p in [other.lower_left, other.upper_right]]) + else: + raise TypeError( + f"Unable to determine if {other} is in the bounding box." + f" Expected a tuple or a bounding box, but {type(other)} given" + ) @property def center(self) -> np.ndarray: diff --git a/tests/unit_tests/test_bounding_box.py b/tests/unit_tests/test_bounding_box.py index 89e50427151..57c880092e3 100644 --- a/tests/unit_tests/test_bounding_box.py +++ b/tests/unit_tests/test_bounding_box.py @@ -78,9 +78,9 @@ def test_bounding_box_input_checking(): def test_bounding_box_extents(): - assert test_bb_1.extent['xy'] == (-10., 1., -20., 2.) - assert test_bb_1.extent['xz'] == (-10., 1., -30., 3.) - assert test_bb_1.extent['yz'] == (-20., 2., -30., 3.) + assert test_bb_1.extent["xy"] == (-10.0, 1.0, -20.0, 2.0) + assert test_bb_1.extent["xz"] == (-10.0, 1.0, -30.0, 3.0) + assert test_bb_1.extent["yz"] == (-20.0, 2.0, -30.0, 3.0) def test_bounding_box_methods(): @@ -156,3 +156,35 @@ def test_bounding_box_methods(): assert all(test_bb[0] == [-50.1, -50.1, -12.1]) assert all(test_bb[1] == [50.1, 14.1, 50.1]) + + +@pytest.mark.parametrize( + "bb, other, expected", + [ + (test_bb_1, (0, 0, 0), True), + (test_bb_2, (3, 3, 3), False), + # completely disjoint + (test_bb_1, test_bb_2, False), + # contained but touching border + (test_bb_1, test_bb_3, False), + # Fully contained + (test_bb_1, openmc.BoundingBox((-9, -19, -29), (0, 0, 0)), True), + # intersecting boxes + (test_bb_1, openmc.BoundingBox((-9, -19, -29), (1, 2, 5)), False), + ], +) +def test_bounding_box_contains(bb, other, expected): + assert (other in bb) == expected + + +@pytest.mark.parametrize( + "invalid, ex", + [ + ((1, 0), ValueError), + ((1, 2, 3, 4), ValueError), + ("foo", TypeError), + ], +) +def test_bounding_box_contains_checking(invalid, ex): + with pytest.raises(ex): + invalid in test_bb_1