Skip to content

Commit

Permalink
Qubit manager prototype without global state using cirq.decompose
Browse files Browse the repository at this point in the history
  • Loading branch information
tanujkhattar committed Apr 3, 2023
1 parent f114996 commit f05f7d7
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 29 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@
qft,
Qid,
QuantumFourierTransformGate,
QubitManager,
QubitOrder,
QubitOrderOrList,
QubitPermutationGate,
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@

from cirq.ops.qid_util import q

from cirq.ops.qubit_manager import QubitManager, GreedyQubitManager, SimpleQubitManager

from cirq.ops.random_gate_channel import RandomGateChannel

from cirq.ops.raw_types import Gate, Operation, Qid, TaggedOperation
Expand Down
5 changes: 5 additions & 0 deletions cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,11 @@ def _num_qubits_(self):
def _decompose_(self) -> 'cirq.OP_TREE':
return protocols.decompose_once_with_qubits(self.gate, self.qubits, NotImplemented)

def _decompose_with_qubit_manager_(self, qubit_manager: 'cirq.QubitManager') -> 'cirq.OP_TREE':
return protocols.decompose_once_with_qubits(
self.gate, self.qubits, NotImplemented, qubit_manager=qubit_manager
)

def _pauli_expansion_(self) -> value.LinearDict[str]:
getter = getattr(self.gate, '_pauli_expansion_', None)
if getter is not None:
Expand Down
89 changes: 60 additions & 29 deletions cirq-core/cirq/protocols/decompose_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -45,7 +47,10 @@
RaiseTypeErrorIfNotProvided: Any = ([],)

DecomposeResult = Union[None, NotImplementedType, 'cirq.OP_TREE']
OpDecomposer = Callable[['cirq.Operation'], DecomposeResult]
OpDecomposer = Union[
Callable[['cirq.Operation'], DecomposeResult],
Callable[['cirq.Operation', 'cirq.QubitManager'], DecomposeResult],
]

DECOMPOSE_TARGET_GATESET = ops.Gateset(
ops.XPowGate,
Expand Down Expand Up @@ -128,6 +133,13 @@ def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> DecomposeResult:
pass


class SupportsDecomposeWithQubitManager(Protocol):
def _decompose_with_qubit_manager_(
self, qubits: Tuple['cirq.Qid', ...], qubit_manager: 'cirq.QubitManager'
) -> DecomposeResult:
pass


def decompose(
val: Any,
*,
Expand All @@ -138,6 +150,7 @@ def decompose(
None, Exception, Callable[['cirq.Operation'], Optional[Exception]]
] = _value_error_describing_bad_operation,
preserve_structure: bool = False,
qubit_manager: Optional['cirq.QubitManager'] = None,
) -> List['cirq.Operation']:
"""Recursively decomposes a value into `cirq.Operation`s meeting a criteria.
Expand Down Expand Up @@ -200,18 +213,24 @@ def decompose(
"acceptable to keep."
)

if qubit_manager is None:
qubit_manager = ops.SimpleQubitManager()

if preserve_structure:
return _decompose_preserving_structure(
val,
intercepting_decomposer=intercepting_decomposer,
fallback_decomposer=fallback_decomposer,
keep=keep,
on_stuck_raise=on_stuck_raise,
qubit_manager=qubit_manager,
)

def try_op_decomposer(val: Any, decomposer: Optional[OpDecomposer]) -> DecomposeResult:
if decomposer is None or not isinstance(val, ops.Operation):
return None
if 'qubit_manager' in inspect.signature(decomposer).parameters:
return decomposer(val, qubit_manager=qubit_manager)
return decomposer(val)

output = []
Expand All @@ -225,7 +244,7 @@ def try_op_decomposer(val: Any, decomposer: Optional[OpDecomposer]) -> Decompose
decomposed = try_op_decomposer(item, intercepting_decomposer)

if decomposed is NotImplemented or decomposed is None:
decomposed = decompose_once(item, default=None)
decomposed = decompose_once(item, default=None, qubit_manager=qubit_manager)

if decomposed is NotImplemented or decomposed is None:
decomposed = try_op_decomposer(item, fallback_decomposer)
Expand Down Expand Up @@ -295,11 +314,12 @@ def decompose_once(val: Any, default=RaiseTypeErrorIfNotProvided, *args, **kwarg
TypeError: `val` didn't have a `_decompose_` method (or that method returned
`NotImplemented` or `None`) and `default` wasn't set.
"""
method = getattr(val, '_decompose_', None)
decomposed = NotImplemented if method is None else method(*args, **kwargs)

if decomposed is not NotImplemented and decomposed is not None:
return list(ops.flatten_op_tree(decomposed))
for strat in ['_decompose_with_qubit_manager_', '_decompose_']:
method = getattr(val, strat, None)
decomposed = NotImplemented if method is None else method(*args, **kwargs)
if decomposed is not NotImplemented and decomposed is not None:
return list(ops.flatten_op_tree(decomposed))
kwargs.pop('qubit_manager', None)

if default is not RaiseTypeErrorIfNotProvided:
return default
Expand All @@ -318,13 +338,19 @@ def decompose_once_with_qubits(val: Any, qubits: Iterable['cirq.Qid']) -> List['

@overload
def decompose_once_with_qubits(
val: Any, qubits: Iterable['cirq.Qid'], default: Optional[TDefault]
val: Any,
qubits: Iterable['cirq.Qid'],
default: Optional[TDefault],
qubit_manager: Optional['cirq.QubitManager'],
) -> Union[TDefault, List['cirq.Operation']]:
pass


def decompose_once_with_qubits(
val: Any, qubits: Iterable['cirq.Qid'], default=RaiseTypeErrorIfNotProvided
val: Any,
qubits: Iterable['cirq.Qid'],
default=RaiseTypeErrorIfNotProvided,
qubit_manager: Optional['cirq.QubitManager'] = None,
):
"""Decomposes a value into operations on the given qubits.
Expand Down Expand Up @@ -352,38 +378,35 @@ def decompose_once_with_qubits(
`val` didn't have a `_decompose_` method (or that method returned
`NotImplemented` or `None`) and `default` wasn't set.
"""
return decompose_once(val, default, tuple(qubits))
return decompose_once(val, default, tuple(qubits), qubit_manager=qubit_manager)


# pylint: enable=function-redefined


def _try_decompose_into_operations_and_qubits(
val: Any,
val: Any, qubit_manager: Optional['cirq.QubitManager'] = None
) -> Tuple[Optional[List['cirq.Operation']], Sequence['cirq.Qid'], Tuple[int, ...]]:
"""Returns the value's decomposition (if any) and the qubits it applies to."""
from cirq.circuits import FrozenCircuit

if qubit_manager is None:
qubit_manager = ops.SimpleQubitManager()
qubits: Sequence[cirq.Qid] = []
if isinstance(val, ops.Gate):
# Gates don't specify qubits, and so must be handled specially.
qid_shape = qid_shape_protocol.qid_shape(val)
qubits: Sequence[cirq.Qid] = devices.LineQid.for_qid_shape(qid_shape)
return decompose_once_with_qubits(val, qubits, None), qubits, qid_shape

if isinstance(val, ops.Operation):
qid_shape = qid_shape_protocol.qid_shape(val)
return decompose_once(val, None), val.qubits, qid_shape

result = decompose_once(val, None)
if result is not None:
qubit_set = set()
qid_shape_dict: Dict[cirq.Qid, int] = defaultdict(lambda: 1)
for op in result:
for level, q in zip(qid_shape_protocol.qid_shape(op), op.qubits):
qubit_set.add(q)
qid_shape_dict[q] = max(qid_shape_dict[q], level)
qubits = sorted(qubit_set)
return result, qubits, tuple(qid_shape_dict[q] for q in qubits)

qubits = devices.LineQid.for_qid_shape(qid_shape)
decomposed = decompose_once_with_qubits(val, qubits, None, qubit_manager=qubit_manager)
elif isinstance(val, ops.Operation):
decomposed = decompose_once(val, None, qubit_manager=qubit_manager)
qubits = val.qubits
else:
decomposed = decompose_once(val, None, qubit_manager=qubit_manager)

if decomposed is not None:
qubits = sorted(FrozenCircuit(decomposed, ops.I.on_each(*qubits)).all_qubits())
return (decomposed, qubits, qid_shape_protocol.qid_shape(qubits))
return None, (), ()


Expand All @@ -396,6 +419,7 @@ def _decompose_preserving_structure(
on_stuck_raise: Union[
None, Exception, Callable[['cirq.Operation'], Optional[Exception]]
] = _value_error_describing_bad_operation,
qubit_manager: Optional['cirq.QubitManager'] = None,
) -> List['cirq.Operation']:
"""Preserves structure (e.g. subcircuits) while decomposing ops.
Expand All @@ -419,10 +443,15 @@ def keep_structure(op: 'cirq.Operation'):
if keep is not None and keep(op):
return True

if qubit_manager is None:
qubit_manager = ops.SimpleQubitManager()

def dps_interceptor(op: 'cirq.Operation'):
if not isinstance(op.untagged, CircuitOperation):
if intercepting_decomposer is None:
return NotImplemented
if 'qubit_manager' in inspect.signature(intercepting_decomposer).parameters:
return intercepting_decomposer(op, qubit_manager=qubit_manager)
return intercepting_decomposer(op)

new_fc = FrozenCircuit(
Expand All @@ -432,6 +461,7 @@ def dps_interceptor(op: 'cirq.Operation'):
fallback_decomposer=fallback_decomposer,
keep=keep_structure,
on_stuck_raise=on_stuck_raise,
qubit_manager=qubit_manager,
)
)
visited_fcs.add(new_fc)
Expand All @@ -444,4 +474,5 @@ def dps_interceptor(op: 'cirq.Operation'):
fallback_decomposer=fallback_decomposer,
keep=keep_structure,
on_stuck_raise=on_stuck_raise,
qubit_manager=qubit_manager,
)
8 changes: 8 additions & 0 deletions cirq-core/cirq/protocols/decompose_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ def _decompose_(self, qubits):
return self.func(*qubits)


class DecomposeWithQubitManagerGiven:
def __init__(self, func):
self.func = func

def _decompose_with_qubit_manager_(self, qubits, qubit_manager):
return self.func(*qubits, qubit_manager=qubit_manager)


class DecomposeGenerated:
def _decompose_(self):
yield cirq.X(cirq.LineQubit(0))
Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/protocols/unitary_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,7 @@ def _strat_unitary_from_decompose(val: Any) -> Optional[np.ndarray]:
if result is None:
return None
state_len = np.prod(val_qid_shape, dtype=np.int64)
# TODO: Can we use linear algebra here to verify that the newly allocated
# ancilla are correctly "freed" at the end of the function? And if the resulting
# effect on system register is unitary, then return the reduced unitary?
return result.reshape((state_len, state_len))
25 changes: 25 additions & 0 deletions cirq-core/cirq/protocols/unitary_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,31 @@ def test_decompose_and_get_unitary():
np.testing.assert_allclose(_strat_unitary_from_decompose(OtherComposite()), m2)


@pytest.mark.parametrize('exp', [-1, -0.8, -0.5, 0, 0.5, 0.8, 1])
def test_ancilla(exp):
class AncillaX(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
ancilla = cirq.ops.qubit_manager.CleanQubit(1)
# yield cirq.X(qubits[0]) ** self._exponent
yield cirq.CX(qubits[0], ancilla)
yield cirq.Z(ancilla) ** self._exponent
yield cirq.CX(qubits[0], ancilla)

q = cirq.LineQubit(0)
gate_u = cirq.unitary(AncillaX(exp))[0:2, 0:2]
op_u = cirq.unitary(AncillaX(exp).on(q))[0:2, 0:2]
exp_u = cirq.unitary(cirq.Z(q) ** exp)
np.testing.assert_allclose(op_u, gate_u)
np.testing.assert_allclose(gate_u, exp_u)
np.testing.assert_allclose(op_u, exp_u)


def test_decomposed_has_unitary():
# Gates
assert cirq.has_unitary(DecomposableGate(True))
Expand Down

0 comments on commit f05f7d7

Please sign in to comment.