Skip to content

Commit

Permalink
Make sure to call the superclass' __init__() on a newly created insta…
Browse files Browse the repository at this point in the history
…nce in PositionalSharding._remake().

If we don't do this, the C++ base class is left in an uninitialized state, leading to failures elsewhere in the test suite.

PiperOrigin-RevId: 673411282
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Sep 11, 2024
1 parent 2bd1fde commit ed849ff
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions jax/_src/sharding_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,13 +691,9 @@ def check_compatible_aval(self, aval_shape: Shape) -> None:
def _remake(
cls, devices: tuple[xc.Device, ...], ids: np.ndarray,
*, memory_kind: str | None = None) -> PositionalSharding:
self = cls.__new__(cls)
self._devices = devices
self._ids = ids
self._internal_device_list = xc.DeviceList(self._devices)
self._memory_kind = xc.check_and_canonicalize_memory_kind(
memory_kind, self._internal_device_list)
return self
sharding = cls(devices, memory_kind=memory_kind)
sharding._ids = ids
return sharding

# Hashable

Expand Down

0 comments on commit ed849ff

Please sign in to comment.