Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix is_complete_for check #587

Merged
merged 4 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions legate/core/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def is_complete_for(self, extents: Shape, offsets: Shape) -> bool:
my_lo = self._offset
my_hi = self._offset + self.tile_shape * self._color_shape

return my_lo <= offsets and offsets + extents <= my_hi
return all(my_lo <= offsets) and all(offsets + extents <= my_hi)

def is_disjoint_for(self, launch_domain: Optional[Rect]) -> bool:
return (
Expand All @@ -236,7 +236,7 @@ def is_disjoint_for(self, launch_domain: Optional[Rect]) -> bool:
)

def has_color(self, color: Shape) -> bool:
return color >= 0 and color < self._color_shape
return all(color >= 0) and all(color < self._color_shape)

@lru_cache
def get_subregion_size(self, extents: Shape, color: Shape) -> Shape:
Expand Down Expand Up @@ -396,7 +396,7 @@ def is_disjoint_for(self, launch_domain: Optional[Rect]) -> bool:
return True

def has_color(self, color: Shape) -> bool:
return color >= 0 and color < self._color_shape
return all(color >= 0) and all(color < self._color_shape)

def translate(self, offset: Shape) -> None:
raise NotImplementedError("This method shouldn't be invoked")
Expand Down
25 changes: 17 additions & 8 deletions legate/core/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def _cast_tuple(value: int | Iterable[int], ndim: int) -> tuple[int, ...]:
return tuple(value)


class _ShapeComparisonResult(tuple[bool, ...]):
def __bool__(self) -> bool:
assert False, "use any() or all()"


class Shape:
_extents: Union[tuple[int, ...], None]
_ispace: Union[IndexSpace, None]
Expand Down Expand Up @@ -154,41 +159,45 @@ def __eq__(self, other: object) -> bool:
else:
return False

def __le__(self, other: ExtentLike) -> bool:
def __le__(self, other: ExtentLike) -> _ShapeComparisonResult:
lh = self.extents
rh = (
other.extents
if isinstance(other, Shape)
else _cast_tuple(other, self.ndim)
)
return len(lh) == len(rh) and lh <= rh
assert len(lh) == len(rh)
return _ShapeComparisonResult(l <= r for (l, r) in zip(lh, rh))

def __lt__(self, other: ExtentLike) -> bool:
def __lt__(self, other: ExtentLike) -> _ShapeComparisonResult:
lh = self.extents
rh = (
other.extents
if isinstance(other, Shape)
else _cast_tuple(other, self.ndim)
)
return len(lh) == len(rh) and lh < rh
assert len(lh) == len(rh)
return _ShapeComparisonResult(l < r for (l, r) in zip(lh, rh))

def __ge__(self, other: ExtentLike) -> bool:
def __ge__(self, other: ExtentLike) -> _ShapeComparisonResult:
lh = self.extents
rh = (
other.extents
if isinstance(other, Shape)
else _cast_tuple(other, self.ndim)
)
return len(lh) == len(rh) and lh >= rh
assert len(lh) == len(rh)
return _ShapeComparisonResult(l >= r for (l, r) in zip(lh, rh))

def __gt__(self, other: ExtentLike) -> bool:
def __gt__(self, other: ExtentLike) -> _ShapeComparisonResult:
lh = self.extents
rh = (
other.extents
if isinstance(other, Shape)
else _cast_tuple(other, self.ndim)
)
return len(lh) == len(rh) and lh > rh
assert len(lh) == len(rh)
return _ShapeComparisonResult(l > r for (l, r) in zip(lh, rh))

def __add__(self, other: ExtentLike) -> Shape:
lh = self.extents
Expand Down