diff --git a/legate/core/partition.py b/legate/core/partition.py index d5904c319..c0eb24ad6 100644 --- a/legate/core/partition.py +++ b/legate/core/partition.py @@ -227,7 +227,7 @@ def is_complete_for(self, extents: Shape, offsets: Shape) -> bool: my_lo = self._offset my_hi = self._offset + self.tile_shape * self._color_shape - return my_lo <= offsets and offsets + extents <= my_hi + return all(my_lo <= offsets) and all(offsets + extents <= my_hi) def is_disjoint_for(self, launch_domain: Optional[Rect]) -> bool: return ( @@ -236,7 +236,7 @@ def is_disjoint_for(self, launch_domain: Optional[Rect]) -> bool: ) def has_color(self, color: Shape) -> bool: - return color >= 0 and color < self._color_shape + return all(color >= 0) and all(color < self._color_shape) @lru_cache def get_subregion_size(self, extents: Shape, color: Shape) -> Shape: @@ -396,7 +396,7 @@ def is_disjoint_for(self, launch_domain: Optional[Rect]) -> bool: return True def has_color(self, color: Shape) -> bool: - return color >= 0 and color < self._color_shape + return all(color >= 0) and all(color < self._color_shape) def translate(self, offset: Shape) -> None: raise NotImplementedError("This method shouldn't be invoked") diff --git a/legate/core/shape.py b/legate/core/shape.py index 98207191f..cd6c44ee5 100644 --- a/legate/core/shape.py +++ b/legate/core/shape.py @@ -32,6 +32,11 @@ def _cast_tuple(value: int | Iterable[int], ndim: int) -> tuple[int, ...]: return tuple(value) +class _ShapeComparisonResult(tuple[bool, ...]): + def __bool__(self) -> bool: + assert False, "use any() or all()" + + class Shape: _extents: Union[tuple[int, ...], None] _ispace: Union[IndexSpace, None] @@ -154,41 +159,45 @@ def __eq__(self, other: object) -> bool: else: return False - def __le__(self, other: ExtentLike) -> bool: + def __le__(self, other: ExtentLike) -> _ShapeComparisonResult: lh = self.extents rh = ( other.extents if isinstance(other, Shape) else _cast_tuple(other, self.ndim) ) - return len(lh) == len(rh) and lh <= rh + assert len(lh) == len(rh) + return _ShapeComparisonResult(l <= r for (l, r) in zip(lh, rh)) - def __lt__(self, other: ExtentLike) -> bool: + def __lt__(self, other: ExtentLike) -> _ShapeComparisonResult: lh = self.extents rh = ( other.extents if isinstance(other, Shape) else _cast_tuple(other, self.ndim) ) - return len(lh) == len(rh) and lh < rh + assert len(lh) == len(rh) + return _ShapeComparisonResult(l < r for (l, r) in zip(lh, rh)) - def __ge__(self, other: ExtentLike) -> bool: + def __ge__(self, other: ExtentLike) -> _ShapeComparisonResult: lh = self.extents rh = ( other.extents if isinstance(other, Shape) else _cast_tuple(other, self.ndim) ) - return len(lh) == len(rh) and lh >= rh + assert len(lh) == len(rh) + return _ShapeComparisonResult(l >= r for (l, r) in zip(lh, rh)) - def __gt__(self, other: ExtentLike) -> bool: + def __gt__(self, other: ExtentLike) -> _ShapeComparisonResult: lh = self.extents rh = ( other.extents if isinstance(other, Shape) else _cast_tuple(other, self.ndim) ) - return len(lh) == len(rh) and lh > rh + assert len(lh) == len(rh) + return _ShapeComparisonResult(l > r for (l, r) in zip(lh, rh)) def __add__(self, other: ExtentLike) -> Shape: lh = self.extents