Skip to content

Commit

Permalink
Speed up circuit building by not forgetting about cached objects (qua…
Browse files Browse the repository at this point in the history
…ntumlib#5280)

This speeds up a 10^2 qubit by 10^3 deep creation of a circuit made up entirely of `cirq.X` gates, created by just appending these gates onto the circuit, from 25s to 5s.

The issue is that Moment's are immutable, so they need to be copied when adding in new operations.  Before this PR we don't copy two cached objects, the measurement key objects, and control keys during this.  This copies these caches over and update them.

Because using `insert` on circuit for earliest insertion strategy has to look up measurement keys or control keys (in order to not move an object with such a key before a moment that has such a key), moments during creating are always being asked what their measurement and control keys are.
  • Loading branch information
dabacon authored May 9, 2022
1 parent 2e78e07 commit 75aa90e
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 21 deletions.
7 changes: 4 additions & 3 deletions cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1880,10 +1880,11 @@ def earliest_available_moment(
while k > 0:
k -= 1
moment = self._moments[k]
if moment.operates_on(op_qubits):
return last_available
moment_measurement_keys = protocols.measurement_key_objs(moment)
if (
moment.operates_on(op_qubits)
or not op_measurement_keys.isdisjoint(moment_measurement_keys)
not op_measurement_keys.isdisjoint(moment_measurement_keys)
or not op_control_keys.isdisjoint(moment_measurement_keys)
or not protocols.control_keys(moment).isdisjoint(op_measurement_keys)
):
Expand Down Expand Up @@ -1955,7 +1956,7 @@ def insert(
Moments within the operation tree are inserted intact.
Args:
index: The index to insert all of the operations at.
index: The index to insert all the operations at.
moment_or_operation_tree: The moment or operation tree to insert.
strategy: How to pick/create the moment to put operations into.
Expand Down
44 changes: 26 additions & 18 deletions cirq/circuits/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(self, *contents: 'cirq.OP_TREE') -> None:
self._qubit_to_op[q] = op

self._qubits = frozenset(self._qubit_to_op.keys())
self._measurement_key_objs: Optional[AbstractSet['cirq.MeasurementKey']] = None
self._measurement_key_objs: Optional[FrozenSet['cirq.MeasurementKey']] = None
self._control_keys: Optional[FrozenSet['cirq.MeasurementKey']] = None

@property
Expand Down Expand Up @@ -166,10 +166,13 @@ def with_operation(self, operation: 'cirq.Operation') -> 'cirq.Moment':
# Use private variables to facilitate a quick copy.
m = Moment()
m._operations = self._operations + (operation,)
m._qubits = frozenset(self._qubits.union(set(operation.qubits)))
m._qubit_to_op = self._qubit_to_op.copy()
for q in operation.qubits:
m._qubit_to_op[q] = operation
m._qubits = self._qubits.union(operation.qubits)
m._qubit_to_op = {**self._qubit_to_op, **{q: operation for q in operation.qubits}}

m._measurement_key_objs = self._measurement_key_objs_().union(
protocols.measurement_key_objs(operation)
)
m._control_keys = self._control_keys_().union(protocols.control_keys(operation))

return m

Expand All @@ -185,22 +188,27 @@ def with_operations(self, *contents: 'cirq.OP_TREE') -> 'cirq.Moment':
Raises:
ValueError: If the contents given overlaps a current operation in the moment.
"""
operations = list(self._operations)
flattened_contents = tuple(op_tree.flatten_to_ops(contents))

m = Moment()
# Use private variables to facilitate a quick copy.
m._qubit_to_op = self._qubit_to_op.copy()
qubits = set(self._qubits)
for op in op_tree.flatten_to_ops(contents):
for op in flattened_contents:
if any(q in qubits for q in op.qubits):
raise ValueError(f'Overlapping operations: {op}')
operations.append(op)
qubits.update(op.qubits)

# Use private variables to facilitate a quick copy.
m = Moment()
m._operations = tuple(operations)
m._qubits = frozenset(qubits)
m._qubit_to_op = self._qubit_to_op.copy()
for op in operations:
for q in op.qubits:
m._qubit_to_op[q] = op
m._qubits = frozenset(qubits)

m._operations = self._operations + flattened_contents
m._measurement_key_objs = self._measurement_key_objs_().union(
set(itertools.chain(*(protocols.measurement_key_objs(op) for op in flattened_contents)))
)
m._control_keys = self._control_keys_().union(
set(itertools.chain(*(protocols.control_keys(op) for op in flattened_contents)))
)

return m

Expand Down Expand Up @@ -233,11 +241,11 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
def _measurement_key_names_(self) -> AbstractSet[str]:
return {str(key) for key in self._measurement_key_objs_()}

def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']:
def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']:
if self._measurement_key_objs is None:
self._measurement_key_objs = {
self._measurement_key_objs = frozenset(
key for op in self.operations for key in protocols.measurement_key_objs(op)
}
)
return self._measurement_key_objs

def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
Expand Down
45 changes: 45 additions & 0 deletions cirq/circuits/moment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,51 @@ def test_measurement_keys():
assert cirq.is_measurement(m2)


def test_measurement_key_objs_caching():
q0, q1, q2, q3 = cirq.LineQubit.range(4)
m = cirq.Moment(cirq.measure(q0, key='foo'))
assert m._measurement_key_objs is None
key_objs = cirq.measurement_key_objs(m)
assert m._measurement_key_objs == key_objs

# Make sure it gets updated when adding an operation.
m = m.with_operation(cirq.measure(q1, key='bar'))
assert m._measurement_key_objs == {
cirq.MeasurementKey(name='bar'),
cirq.MeasurementKey(name='foo'),
}
# Or multiple operations.
m = m.with_operations(cirq.measure(q2, key='doh'), cirq.measure(q3, key='baz'))
assert m._measurement_key_objs == {
cirq.MeasurementKey(name='bar'),
cirq.MeasurementKey(name='foo'),
cirq.MeasurementKey(name='doh'),
cirq.MeasurementKey(name='baz'),
}


def test_control_keys_caching():
q0, q1, q2, q3 = cirq.LineQubit.range(4)
m = cirq.Moment(cirq.X(q0).with_classical_controls('foo'))
assert m._control_keys is None
keys = cirq.control_keys(m)
assert m._control_keys == keys

# Make sure it gets updated when adding an operation.
m = m.with_operation(cirq.X(q1).with_classical_controls('bar'))
assert m._control_keys == {cirq.MeasurementKey(name='bar'), cirq.MeasurementKey(name='foo')}
# Or multiple operations.
m = m.with_operations(
cirq.X(q2).with_classical_controls('doh'), cirq.X(q3).with_classical_controls('baz')
)
assert m._control_keys == {
cirq.MeasurementKey(name='bar'),
cirq.MeasurementKey(name='foo'),
cirq.MeasurementKey(name='doh'),
cirq.MeasurementKey(name='baz'),
}


def test_bool():
assert not cirq.Moment()
a = cirq.NamedQubit('a')
Expand Down

0 comments on commit 75aa90e

Please sign in to comment.