Skip to content

Commit

Permalink
Fix is_complete_for check (#587)
Browse files Browse the repository at this point in the history
* Fix is_complete_for check

* Fix another use of inequality on Shape

* Return a bool tuple from Shape comparisons

---------

Co-authored-by: Manolis Papadakis <[email protected]>
  • Loading branch information
manopapad and manopapad authored Feb 28, 2023
1 parent 6782f21 commit 33ce363
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
6 changes: 3 additions & 3 deletions legate/core/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def is_complete_for(self, extents: Shape, offsets: Shape) -> bool:
my_lo = self._offset
my_hi = self._offset + self.tile_shape * self._color_shape

return my_lo <= offsets and offsets + extents <= my_hi
return all(my_lo <= offsets) and all(offsets + extents <= my_hi)

def is_disjoint_for(self, launch_domain: Optional[Rect]) -> bool:
return (
Expand All @@ -236,7 +236,7 @@ def is_disjoint_for(self, launch_domain: Optional[Rect]) -> bool:
)

def has_color(self, color: Shape) -> bool:
return color >= 0 and color < self._color_shape
return all(color >= 0) and all(color < self._color_shape)

@lru_cache
def get_subregion_size(self, extents: Shape, color: Shape) -> Shape:
Expand Down Expand Up @@ -396,7 +396,7 @@ def is_disjoint_for(self, launch_domain: Optional[Rect]) -> bool:
return True

def has_color(self, color: Shape) -> bool:
return color >= 0 and color < self._color_shape
return all(color >= 0) and all(color < self._color_shape)

def translate(self, offset: Shape) -> None:
raise NotImplementedError("This method shouldn't be invoked")
Expand Down
25 changes: 17 additions & 8 deletions legate/core/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def _cast_tuple(value: int | Iterable[int], ndim: int) -> tuple[int, ...]:
return tuple(value)


class _ShapeComparisonResult(tuple[bool, ...]):
def __bool__(self) -> bool:
assert False, "use any() or all()"


class Shape:
_extents: Union[tuple[int, ...], None]
_ispace: Union[IndexSpace, None]
Expand Down Expand Up @@ -154,41 +159,45 @@ def __eq__(self, other: object) -> bool:
else:
return False

def __le__(self, other: ExtentLike) -> bool:
def __le__(self, other: ExtentLike) -> _ShapeComparisonResult:
lh = self.extents
rh = (
other.extents
if isinstance(other, Shape)
else _cast_tuple(other, self.ndim)
)
return len(lh) == len(rh) and lh <= rh
assert len(lh) == len(rh)
return _ShapeComparisonResult(l <= r for (l, r) in zip(lh, rh))

def __lt__(self, other: ExtentLike) -> bool:
def __lt__(self, other: ExtentLike) -> _ShapeComparisonResult:
lh = self.extents
rh = (
other.extents
if isinstance(other, Shape)
else _cast_tuple(other, self.ndim)
)
return len(lh) == len(rh) and lh < rh
assert len(lh) == len(rh)
return _ShapeComparisonResult(l < r for (l, r) in zip(lh, rh))

def __ge__(self, other: ExtentLike) -> bool:
def __ge__(self, other: ExtentLike) -> _ShapeComparisonResult:
lh = self.extents
rh = (
other.extents
if isinstance(other, Shape)
else _cast_tuple(other, self.ndim)
)
return len(lh) == len(rh) and lh >= rh
assert len(lh) == len(rh)
return _ShapeComparisonResult(l >= r for (l, r) in zip(lh, rh))

def __gt__(self, other: ExtentLike) -> bool:
def __gt__(self, other: ExtentLike) -> _ShapeComparisonResult:
lh = self.extents
rh = (
other.extents
if isinstance(other, Shape)
else _cast_tuple(other, self.ndim)
)
return len(lh) == len(rh) and lh > rh
assert len(lh) == len(rh)
return _ShapeComparisonResult(l > r for (l, r) in zip(lh, rh))

def __add__(self, other: ExtentLike) -> Shape:
lh = self.extents
Expand Down

0 comments on commit 33ce363

Please sign in to comment.