diff --git a/cirq-core/cirq/devices/grid_qubit.py b/cirq-core/cirq/devices/grid_qubit.py index 6344a88cfff..01f77e3d636 100644 --- a/cirq-core/cirq/devices/grid_qubit.py +++ b/cirq-core/cirq/devices/grid_qubit.py @@ -14,6 +14,7 @@ import abc import functools +import weakref from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, TYPE_CHECKING, Union from typing_extensions import Self @@ -34,14 +35,6 @@ class _BaseGridQid(ops.Qid): _dimension: int _hash: Optional[int] = None - def __getstate__(self): - # Don't save hash when pickling; see #3777. - state = self.__dict__ - if "_hash" in state: - state = state.copy() - del state["_hash"] - return state - def __hash__(self) -> int: if self._hash is None: self._hash = hash((self._row, self._col, self._dimension)) @@ -50,7 +43,7 @@ def __hash__(self) -> int: def __eq__(self, other): # Explicitly implemented for performance (vs delegating to Qid). if isinstance(other, _BaseGridQid): - return ( + return self is other or ( self._row == other._row and self._col == other._col and self._dimension == other._dimension @@ -60,7 +53,7 @@ def __eq__(self, other): def __ne__(self, other): # Explicitly implemented for performance (vs delegating to Qid). if isinstance(other, _BaseGridQid): - return ( + return self is not other and ( self._row != other._row or self._col != other._col or self._dimension != other._dimension @@ -178,8 +171,12 @@ class GridQid(_BaseGridQid): cirq.GridQid(5, 4, dimension=2) """ - def __init__(self, row: int, col: int, *, dimension: int) -> None: - """Initializes a grid qid at the given row, col coordinate + # Cache of existing GridQid instances, returned by __new__ if available. + # Holds weak references so instances can still be garbage collected. + _cache = weakref.WeakValueDictionary[Tuple[int, int, int], 'cirq.GridQid']() + + def __new__(cls, row: int, col: int, *, dimension: int) -> 'cirq.GridQid': + """Creates a grid qid at the given row, col coordinate Args: row: the row coordinate @@ -187,13 +184,23 @@ def __init__(self, row: int, col: int, *, dimension: int) -> None: dimension: The dimension of the qid's Hilbert space, i.e. the number of quantum levels. """ - self.validate_dimension(dimension) - self._row = row - self._col = col - self._dimension = dimension + key = (row, col, dimension) + inst = cls._cache.get(key) + if inst is None: + cls.validate_dimension(dimension) + inst = super().__new__(cls) + inst._row = row + inst._col = col + inst._dimension = dimension + cls._cache[key] = inst + return inst + + def __getnewargs_ex__(self): + """Returns a tuple of (args, kwargs) to pass to __new__ when unpickling.""" + return (self._row, self._col), {"dimension": self._dimension} def _with_row_col(self, row: int, col: int) -> 'GridQid': - return GridQid(row, col, dimension=self.dimension) + return GridQid(row, col, dimension=self._dimension) @staticmethod def square(diameter: int, top: int = 0, left: int = 0, *, dimension: int) -> List['GridQid']: @@ -290,16 +297,16 @@ def from_diagram(diagram: str, dimension: int) -> List['GridQid']: return [GridQid(*c, dimension=dimension) for c in coords] def __repr__(self) -> str: - return f"cirq.GridQid({self._row}, {self._col}, dimension={self.dimension})" + return f"cirq.GridQid({self._row}, {self._col}, dimension={self._dimension})" def __str__(self) -> str: - return f"q({self._row}, {self._col}) (d={self.dimension})" + return f"q({self._row}, {self._col}) (d={self._dimension})" def _circuit_diagram_info_( self, args: 'cirq.CircuitDiagramInfoArgs' ) -> 'cirq.CircuitDiagramInfo': return protocols.CircuitDiagramInfo( - wire_symbols=(f"({self._row}, {self._col}) (d={self.dimension})",) + wire_symbols=(f"({self._row}, {self._col}) (d={self._dimension})",) ) def _json_dict_(self) -> Dict[str, Any]: @@ -325,11 +332,31 @@ class GridQubit(_BaseGridQid): _dimension = 2 - def __init__(self, row: int, col: int) -> None: - self._row = row - self._col = col + # Cache of existing GridQubit instances, returned by __new__ if available. + # Holds weak references so instances can still be garbage collected. + _cache = weakref.WeakValueDictionary[Tuple[int, int], 'cirq.GridQubit']() - def _with_row_col(self, row: int, col: int): + def __new__(cls, row: int, col: int) -> 'cirq.GridQubit': + """Creates a grid qubit at the given row, col coordinate + + Args: + row: the row coordinate + col: the column coordinate + """ + key = (row, col) + inst = cls._cache.get(key) + if inst is None: + inst = super().__new__(cls) + inst._row = row + inst._col = col + cls._cache[key] = inst + return inst + + def __getnewargs__(self): + """Returns a tuple of args to pass to __new__ when unpickling.""" + return (self._row, self._col) + + def _with_row_col(self, row: int, col: int) -> 'GridQubit': return GridQubit(row, col) def _cmp_tuple(self): diff --git a/cirq-core/cirq/devices/grid_qubit_test.py b/cirq-core/cirq/devices/grid_qubit_test.py index 2f642806ddd..b6c51f68b39 100644 --- a/cirq-core/cirq/devices/grid_qubit_test.py +++ b/cirq-core/cirq/devices/grid_qubit_test.py @@ -40,11 +40,29 @@ def test_eq(): eq.make_equality_group(lambda: cirq.GridQid(0, 0, dimension=3)) -def test_pickled_hash(): - q = cirq.GridQubit(3, 4) - q_bad = cirq.GridQubit(3, 4) +def test_grid_qubit_pickled_hash(): + # Use a large number that is unlikely to be used by any other tests. + row, col = 123456789, 2345678910 + q_bad = cirq.GridQubit(row, col) + cirq.GridQubit._cache.pop((row, col)) + q = cirq.GridQubit(row, col) + _test_qid_pickled_hash(q, q_bad) + + +def test_grid_qid_pickled_hash(): + # Use a large number that is unlikely to be used by any other tests. + row, col = 123456789, 2345678910 + q_bad = cirq.GridQid(row, col, dimension=3) + cirq.GridQid._cache.pop((row, col, 3)) + q = cirq.GridQid(row, col, dimension=3) + _test_qid_pickled_hash(q, q_bad) + + +def _test_qid_pickled_hash(q: 'cirq.Qid', q_bad: 'cirq.Qid') -> None: + """Test that hashes are not pickled with Qid instances.""" + assert q_bad is not q _ = hash(q_bad) # compute hash to ensure it is cached. - q_bad._hash = q_bad._hash + 1 + q_bad._hash = q_bad._hash + 1 # type: ignore[attr-defined] assert q_bad == q assert hash(q_bad) != hash(q) data = pickle.dumps(q_bad) diff --git a/cirq-core/cirq/devices/line_qubit.py b/cirq-core/cirq/devices/line_qubit.py index 2f9bf6a6bca..71d292b5e3d 100644 --- a/cirq-core/cirq/devices/line_qubit.py +++ b/cirq-core/cirq/devices/line_qubit.py @@ -14,7 +14,8 @@ import abc import functools -from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, TYPE_CHECKING, Union +import weakref +from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union from typing_extensions import Self from cirq import ops, protocols @@ -31,14 +32,6 @@ class _BaseLineQid(ops.Qid): _dimension: int _hash: Optional[int] = None - def __getstate__(self): - # Don't save hash when pickling; see #3777. - state = self.__dict__ - if "_hash" in state: - state = state.copy() - del state["_hash"] - return state - def __hash__(self) -> int: if self._hash is None: self._hash = hash((self._x, self._dimension)) @@ -47,13 +40,15 @@ def __hash__(self) -> int: def __eq__(self, other): # Explicitly implemented for performance (vs delegating to Qid). if isinstance(other, _BaseLineQid): - return self._x == other._x and self._dimension == other._dimension + return self is other or (self._x == other._x and self._dimension == other._dimension) return NotImplemented def __ne__(self, other): # Explicitly implemented for performance (vs delegating to Qid). if isinstance(other, _BaseLineQid): - return self._x != other._x or self._dimension != other._dimension + return self is not other and ( + self._x != other._x or self._dimension != other._dimension + ) return NotImplemented def _comparison_key(self): @@ -154,7 +149,11 @@ class LineQid(_BaseLineQid): """ - def __init__(self, x: int, dimension: int) -> None: + # Cache of existing LineQid instances, returned by __new__ if available. + # Holds weak references so instances can still be garbage collected. + _cache = weakref.WeakValueDictionary[Tuple[int, int], 'cirq.LineQid']() + + def __new__(cls, x: int, dimension: int) -> 'cirq.LineQid': """Initializes a line qid at the given x coordinate. Args: @@ -162,9 +161,19 @@ def __init__(self, x: int, dimension: int) -> None: dimension: The dimension of the qid's Hilbert space, i.e. the number of quantum levels. """ - self.validate_dimension(dimension) - self._x = x - self._dimension = dimension + key = (x, dimension) + inst = cls._cache.get(key) + if inst is None: + cls.validate_dimension(dimension) + inst = super().__new__(cls) + inst._x = x + inst._dimension = dimension + cls._cache[key] = inst + return inst + + def __getnewargs__(self): + """Returns a tuple of args to pass to __new__ when unpickling.""" + return (self._x, self._dimension) def _with_x(self, x: int) -> 'LineQid': return LineQid(x, dimension=self._dimension) @@ -246,13 +255,26 @@ class LineQubit(_BaseLineQid): _dimension = 2 - def __init__(self, x: int) -> None: - """Initializes a line qubit at the given x coordinate. + # Cache of existing LineQubit instances, returned by __new__ if available. + # Holds weak references so instances can still be garbage collected. + _cache = weakref.WeakValueDictionary[int, 'cirq.LineQubit']() + + def __new__(cls, x: int) -> 'cirq.LineQubit': + """Initializes a line qid at the given x coordinate. Args: x: The x coordinate. """ - self._x = x + inst = cls._cache.get(x) + if inst is None: + inst = super().__new__(cls) + inst._x = x + cls._cache[x] = inst + return inst + + def __getnewargs__(self): + """Returns a tuple of args to pass to __new__ when unpickling.""" + return (self._x,) def _with_x(self, x: int) -> 'LineQubit': return LineQubit(x) diff --git a/cirq-core/cirq/devices/line_qubit_test.py b/cirq-core/cirq/devices/line_qubit_test.py index 6c8474f313a..85d7ecd73a8 100644 --- a/cirq-core/cirq/devices/line_qubit_test.py +++ b/cirq-core/cirq/devices/line_qubit_test.py @@ -15,6 +15,7 @@ import pytest import cirq +from cirq.devices.grid_qubit_test import _test_qid_pickled_hash def test_init(): @@ -67,6 +68,24 @@ def test_cmp_failure(): _ = cirq.LineQid(1, 3) < 0 +def test_line_qubit_pickled_hash(): + # Use a large number that is unlikely to be used by any other tests. + x = 1234567891011 + q_bad = cirq.LineQubit(x) + cirq.LineQubit._cache.pop(x) + q = cirq.LineQubit(x) + _test_qid_pickled_hash(q, q_bad) + + +def test_line_qid_pickled_hash(): + # Use a large number that is unlikely to be used by any other tests. + x = 1234567891011 + q_bad = cirq.LineQid(x, dimension=3) + cirq.LineQid._cache.pop((x, 3)) + q = cirq.LineQid(x, dimension=3) + _test_qid_pickled_hash(q, q_bad) + + def test_is_adjacent(): assert cirq.LineQubit(1).is_adjacent(cirq.LineQubit(2)) assert cirq.LineQubit(1).is_adjacent(cirq.LineQubit(0)) diff --git a/cirq-core/cirq/ops/named_qubit.py b/cirq-core/cirq/ops/named_qubit.py index 6024c91dcb7..7f7bccaf516 100644 --- a/cirq-core/cirq/ops/named_qubit.py +++ b/cirq-core/cirq/ops/named_qubit.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import Any, Dict, List, Optional, TYPE_CHECKING +import weakref +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING from cirq import protocols from cirq.ops import raw_types @@ -31,17 +32,6 @@ class _BaseNamedQid(raw_types.Qid): _comp_key: Optional[str] = None _hash: Optional[int] = None - def __getstate__(self): - # Don't save hash when pickling; see #3777. - state = self.__dict__ - if "_hash" in state or "_comp_key" in state: - state = state.copy() - if "_hash" in state: - del state["_hash"] - if "_comp_key" in state: - del state["_comp_key"] - return state - def __hash__(self) -> int: if self._hash is None: self._hash = hash((self._name, self._dimension)) @@ -50,13 +40,17 @@ def __hash__(self) -> int: def __eq__(self, other): # Explicitly implemented for performance (vs delegating to Qid). if isinstance(other, _BaseNamedQid): - return self._name == other._name and self._dimension == other._dimension + return self is other or ( + self._name == other._name and self._dimension == other._dimension + ) return NotImplemented def __ne__(self, other): # Explicitly implemented for performance (vs delegating to Qid). if isinstance(other, _BaseNamedQid): - return self._name != other._name or self._dimension != other._dimension + return self is not other and ( + self._name != other._name or self._dimension != other._dimension + ) return NotImplemented def _comparison_key(self): @@ -86,7 +80,11 @@ class NamedQid(_BaseNamedQid): correctly come before 'qid22'. """ - def __init__(self, name: str, dimension: int) -> None: + # Cache of existing NamedQid instances, returned by __new__ if available. + # Holds weak references so instances can still be garbage collected. + _cache = weakref.WeakValueDictionary[Tuple[str, int], 'cirq.NamedQid']() + + def __new__(cls, name: str, dimension: int) -> 'cirq.NamedQid': """Initializes a `NamedQid` with a given name and dimension. Args: @@ -94,9 +92,19 @@ def __init__(self, name: str, dimension: int) -> None: dimension: The dimension of the qid's Hilbert space, i.e. the number of quantum levels. """ - self.validate_dimension(dimension) - self._name = name - self._dimension = dimension + key = (name, dimension) + inst = cls._cache.get(key) + if inst is None: + cls.validate_dimension(dimension) + inst = super().__new__(cls) + inst._name = name + inst._dimension = dimension + cls._cache[key] = inst + return inst + + def __getnewargs__(self): + """Returns a tuple of args to pass to __new__ when unpickling.""" + return (self._name, self._dimension) def __repr__(self) -> str: return f'cirq.NamedQid({self._name!r}, dimension={self._dimension})' @@ -143,13 +151,28 @@ class NamedQubit(_BaseNamedQid): _dimension = 2 - def __init__(self, name: str) -> None: - """Initializes a `NamedQubit` with a given name. + # Cache of existing NamedQubit instances, returned by __new__ if available. + # Holds weak references so instances can still be garbage collected. + _cache = weakref.WeakValueDictionary[str, 'cirq.NamedQubit']() + + def __new__(cls, name: str) -> 'cirq.NamedQubit': + """Initializes a `NamedQid` with a given name and dimension. Args: name: The name. + dimension: The dimension of the qid's Hilbert space, i.e. + the number of quantum levels. """ - self._name = name + inst = cls._cache.get(name) + if inst is None: + inst = super().__new__(cls) + inst._name = name + cls._cache[name] = inst + return inst + + def __getnewargs__(self): + """Returns a tuple of args to pass to __new__ when unpickling.""" + return (self._name,) def _cmp_tuple(self): cls = NamedQid if type(self) is NamedQubit else type(self) diff --git a/cirq-core/cirq/ops/named_qubit_test.py b/cirq-core/cirq/ops/named_qubit_test.py index 12611b16cab..6e8b79354fa 100644 --- a/cirq-core/cirq/ops/named_qubit_test.py +++ b/cirq-core/cirq/ops/named_qubit_test.py @@ -13,6 +13,7 @@ # limitations under the License. import cirq +from cirq.devices.grid_qubit_test import _test_qid_pickled_hash from cirq.ops.named_qubit import _pad_digits @@ -41,6 +42,24 @@ def test_named_qubit_repr(): assert repr(qid) == "cirq.NamedQid('a', dimension=3)" +def test_named_qubit_pickled_hash(): + # Use a name that is unlikely to be used by any other tests. + x = "test_named_qubit_pickled_hash" + q_bad = cirq.NamedQubit(x) + cirq.NamedQubit._cache.pop(x) + q = cirq.NamedQubit(x) + _test_qid_pickled_hash(q, q_bad) + + +def test_named_qid_pickled_hash(): + # Use a name that is unlikely to be used by any other tests. + x = "test_named_qid_pickled_hash" + q_bad = cirq.NamedQid(x, dimension=3) + cirq.NamedQid._cache.pop((x, 3)) + q = cirq.NamedQid(x, dimension=3) + _test_qid_pickled_hash(q, q_bad) + + def test_named_qubit_order(): order = cirq.testing.OrderTester() order.add_ascending(