Skip to content

Commit

Permalink
Cache the hash of frozen circuits to avoid recomputing (quantumlib#5738)
Browse files Browse the repository at this point in the history
This follows what we previously did for GridQubit, but reworked to use `_compat.cached_method` and to compute the hash lazily.
  • Loading branch information
maffoo authored Jul 12, 2022
1 parent 8de6868 commit 04c02aa
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 15 deletions.
13 changes: 12 additions & 1 deletion cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,20 @@ def __init__(
def moments(self) -> Sequence['cirq.Moment']:
return self._moments

def __hash__(self):
@_compat.cached_method
def __hash__(self) -> int:
# Explicitly cached for performance
return hash((self.moments,))

def __getstate__(self):
# Don't save hash when pickling; see #3777.
state = self.__dict__
hash_cache = _compat._method_cache_name(self.__hash__)
if hash_cache in state:
state = state.copy()
del state[hash_cache]
return state

@_compat.cached_method
def _num_qubits_(self) -> int:
return len(self.all_qubits())
Expand Down
22 changes: 9 additions & 13 deletions cirq/devices/grid_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import numpy as np

from cirq import ops, protocols
from cirq import _compat, ops, protocols

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -294,23 +294,19 @@ class GridQubit(_BaseGridQid):
cirq.GridQubit(5, 4)
"""

def __init__(self, row: int, col: int):
super().__init__(row, col)
self._hash = super().__hash__()

def __getstate__(self):
# Don't save hash when pickling; see #3777.
state = self.__dict__.copy()
del state['_hash']
state = self.__dict__
hash_key = _compat._method_cache_name(self.__hash__)
if hash_key in state:
state = state.copy()
del state[hash_key]
return state

def __setstate__(self, state):
self.__dict__.update(state)
self._hash = super().__hash__()

def __hash__(self):
@_compat.cached_method
def __hash__(self) -> int:
# Explicitly cached for performance (vs delegating to Qid).
return self._hash
return super().__hash__()

def __eq__(self, other):
# Explicitly implemented for performance (vs delegating to Qid).
Expand Down
5 changes: 4 additions & 1 deletion cirq/devices/grid_qubit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pytest

import cirq
from cirq import _compat


def test_init():
Expand All @@ -43,7 +44,9 @@ def test_eq():
def test_pickled_hash():
q = cirq.GridQubit(3, 4)
q_bad = cirq.GridQubit(3, 4)
q_bad._hash += 1
_ = hash(q_bad) # compute hash to ensure it is cached.
hash_key = _compat._method_cache_name(cirq.GridQubit.__hash__)
setattr(q_bad, hash_key, getattr(q_bad, hash_key) + 1)
assert q_bad == q
assert hash(q_bad) != hash(q)
data = pickle.dumps(q_bad)
Expand Down

0 comments on commit 04c02aa

Please sign in to comment.