From ed849ff9e0576dcee2514741b5ffa951a94e20a8 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 11 Sep 2024 08:54:08 -0700 Subject: [PATCH] Make sure to call the superclass' __init__() on a newly created instance in PositionalSharding._remake(). If we don't do this, the C++ base class is left in an uninitialized state, leading to failures elsewhere in the test suite. PiperOrigin-RevId: 673411282 --- jax/_src/sharding_impls.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 0b1dc082765e..add297b6a351 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -691,13 +691,9 @@ def check_compatible_aval(self, aval_shape: Shape) -> None: def _remake( cls, devices: tuple[xc.Device, ...], ids: np.ndarray, *, memory_kind: str | None = None) -> PositionalSharding: - self = cls.__new__(cls) - self._devices = devices - self._ids = ids - self._internal_device_list = xc.DeviceList(self._devices) - self._memory_kind = xc.check_and_canonicalize_memory_kind( - memory_kind, self._internal_device_list) - return self + sharding = cls(devices, memory_kind=memory_kind) + sharding._ids = ids + return sharding # Hashable