Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix in Cirq Interop: Attempt 2 #1100

Merged
merged 8 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions qualtran/_infra/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import cirq
from attrs import frozen
from numpy.typing import NDArray

from .composite_bloq import _binst_to_cxns, _cxns_to_soq_dict, _map_soqs, _reg_to_soq, BloqBuilder
from .gate_with_registers import GateWithRegisters
Expand Down Expand Up @@ -142,13 +141,6 @@ def decompose_bloq(self) -> 'CompositeBloq':
"""The decomposition is the adjoint of `subbloq`'s decomposition."""
return self.subbloq.decompose_bloq().adjoint()

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] # type: ignore[type-var]
) -> cirq.OP_TREE:
if isinstance(self.subbloq, GateWithRegisters):
return cirq.inverse(self.subbloq.decompose_from_registers(context=context, **quregs))
return super().decompose_from_registers(context=context, **quregs)

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> cirq.CircuitDiagramInfo:
Expand Down
6 changes: 3 additions & 3 deletions qualtran/_infra/gate_with_registers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,11 @@ def test_gate_with_registers_decompose_from_context_auto_generated():
cirq.testing.assert_has_diagram(
circuit,
"""
l: ───BloqWithDecompose───X───────free───
l: ───BloqWithDecompose───X───
r: ───r───────────────────alloc───Z──────
r: ───r───────────────────Z───
t: ───t───────────────────Y──────────────
t: ───t───────────────────Y───
""",
)

Expand Down
3 changes: 3 additions & 0 deletions qualtran/bloqs/arithmetic/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,9 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
def _has_unitary_(self):
return True

def adjoint(self) -> 'Bloq':
return self


@bloq_example
def _leq_symb() -> LessThanEqual:
Expand Down
3 changes: 2 additions & 1 deletion qualtran/bloqs/basic_gates/s_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def as_cirq_op(
import cirq

(q,) = q
return cirq.S(q), {'q': np.array([q])}
p = -1 if self.is_adjoint else 1
return cirq.S(q) ** p, {'q': np.array([q])}

def pretty_name(self) -> str:
maybe_dag = '†' if self.is_adjoint else ''
Expand Down
3 changes: 2 additions & 1 deletion qualtran/bloqs/basic_gates/s_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ def test_to_cirq():
bb = BloqBuilder()
q = bb.add(PlusState())
q = bb.add(SGate(), q=q)
q = bb.add(SGate().adjoint(), q=q)
cbloq = bb.finalize(q=q)
circuit = cbloq.to_cirq_circuit()
cirq.testing.assert_has_diagram(circuit, "_c(0): ───H───S───")
cirq.testing.assert_has_diagram(circuit, "_c(0): ───H───S───S^-1───")


def test_tensors():
Expand Down
15 changes: 14 additions & 1 deletion qualtran/bloqs/bookkeeping/allocate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import Dict, List, Tuple, TYPE_CHECKING
from typing import Dict, List, Tuple, TYPE_CHECKING, Union

import numpy as np
import sympy
from attrs import frozen

Expand All @@ -34,8 +35,11 @@
from qualtran.drawing import directional_text_box, Text, WireSymbol

if TYPE_CHECKING:
import cirq
import quimb.tensor as qtn

from qualtran.cirq_interop import CirqQuregT


@frozen
class Allocate(_BookkeepingBloq):
Expand Down Expand Up @@ -83,6 +87,15 @@ def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSym
assert reg.name == 'reg'
return directional_text_box('alloc', Side.RIGHT)

def as_cirq_op(
self, qubit_manager: 'cirq.QubitManager'
) -> Tuple[Union['cirq.Operation', None], Dict[str, 'CirqQuregT']]:
shape = (*self.signature[0].shape, self.signature[0].bitsize)
return (
None,
{'reg': np.array(qubit_manager.qalloc(self.signature.n_qubits())).reshape(shape)},
)


@bloq_example
def _alloc() -> Allocate:
Expand Down
10 changes: 9 additions & 1 deletion qualtran/bloqs/bookkeeping/free.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from functools import cached_property
from typing import Dict, List, Tuple, TYPE_CHECKING
from typing import Dict, List, Tuple, TYPE_CHECKING, Union

import sympy
from attrs import frozen
Expand All @@ -35,8 +35,10 @@
from qualtran.drawing import directional_text_box, Text, WireSymbol

if TYPE_CHECKING:
import cirq
import quimb.tensor as qtn

from qualtran.cirq_interop import CirqQuregT
from qualtran.simulation.classical_sim import ClassicalValT


Expand Down Expand Up @@ -92,6 +94,12 @@ def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSym
assert reg.name == 'reg'
return directional_text_box('free', Side.LEFT)

def as_cirq_op(
self, qubit_manager: 'cirq.QubitManager', reg: 'CirqQuregT'
) -> Tuple[Union['cirq.Operation', None], Dict[str, 'CirqQuregT']]:
qubit_manager.qfree(reg.flatten().tolist())
return (None, {})


@bloq_example
def _free() -> Free:
Expand Down
6 changes: 3 additions & 3 deletions qualtran/bloqs/data_loading/select_swap_qrom_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def test_select_swap_qrom(data, block_size):
cirq.decompose_once(qrom.on_registers(**qubit_regs), context=context)
)

dirty_target_ancilla = [
q for q in qrom_circuit.all_qubits() if isinstance(q, cirq.ops.BorrowableQubit)
]
dirty_target_ancilla = sorted(
qrom_circuit.all_qubits() - set(q for qs in qubit_regs.values() for q in qs.flatten())
)

circuit = cirq.Circuit(
# Prepare dirty ancillas in an arbitrary state.
Expand Down
4 changes: 2 additions & 2 deletions qualtran/bloqs/reflections/reflection_using_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@
from qualtran.bloqs.basic_gates.global_phase import GlobalPhase
from qualtran.bloqs.basic_gates.rotation import ZPowGate
from qualtran.bloqs.basic_gates.x_basis import XGate
from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle
from qualtran.bloqs.mcmt.multi_control_multi_target_pauli import MultiControlPauli
from qualtran.resource_counting.generalizers import ignore_split_join
from qualtran.symbolics.types import SymbolicInt

if TYPE_CHECKING:
from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator


Expand Down Expand Up @@ -77,7 +77,7 @@ class ReflectionUsingPrepare(SpecializedSingleQubitControlledGate):
Babbush et. al. (2018). Figure 1.
"""

prepare_gate: PrepareOracle
prepare_gate: 'PrepareOracle'
control_val: Optional[int] = None
global_phase: complex = 1
eps: float = 1e-11
Expand Down
12 changes: 10 additions & 2 deletions qualtran/bloqs/reflections/reflection_using_prepare_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pytest
from numpy.typing import NDArray

from qualtran import Bloq
from qualtran import Adjoint, Bloq
from qualtran._infra.gate_with_registers import get_named_qubits
from qualtran.bloqs.arithmetic import LessThanConstant, LessThanEqual
from qualtran.bloqs.basic_gates import ZPowGate
Expand All @@ -30,6 +30,7 @@
from qualtran.bloqs.reflections.prepare_identity import PrepareIdentity
from qualtran.bloqs.reflections.reflection_using_prepare import ReflectionUsingPrepare
from qualtran.bloqs.state_preparation import StatePreparationAliasSampling
from qualtran.cirq_interop import BloqAsCirqGate
from qualtran.cirq_interop.testing import GateHelper
from qualtran.resource_counting.generalizers import (
ignore_alloc_free,
Expand Down Expand Up @@ -58,6 +59,13 @@ def keep(op: cirq.Operation):
ret = op in gateset_to_keep
if op.gate is not None and isinstance(op.gate, cirq.ops.raw_types._InverseCompositeGate):
ret |= op.gate._original in gateset_to_keep
if op.gate is not None and isinstance(op.gate, Adjoint):
subgate = (
op.gate.subbloq
if isinstance(op.gate.subbloq, cirq.Gate)
else BloqAsCirqGate(op.gate.subbloq)
)
ret |= subgate in gateset_to_keep
return ret


Expand All @@ -73,7 +81,7 @@ def construct_gate_helper_and_qubit_order(gate, decompose_once: bool = False):
)
ordered_input = list(itertools.chain(*g.quregs.values()))
qubit_order = cirq.QubitOrder.explicit(ordered_input, fallback=cirq.QubitOrder.DEFAULT)
assert len(circuit.all_qubits()) < 30
assert len(circuit.all_qubits()) < 24
return g, qubit_order, circuit


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,16 @@ def test_prepare_uniform_superposition_consistent_protocols():
PrepareUniformSuperposition(5, cvs=()),
PrepareUniformSuperposition(5, cvs=[]),
)


def test_prepare_uniform_superposition_adjoint():
n = 3
target = cirq.NamedQubit.range((n - 1).bit_length(), prefix='target')
control = [cirq.NamedQubit('control')]
op = PrepareUniformSuperposition(n, cvs=(0,)).on_registers(ctrl=control, target=target)
gqm = cirq.GreedyQubitManager(prefix="_ancilla", maximize_reuse=True)
context = cirq.DecompositionContext(gqm)
circuit = cirq.Circuit(op, cirq.decompose(cirq.inverse(op), context=context))
identity = cirq.Circuit(cirq.identity_each(*circuit.all_qubits())).final_state_vector()
result = cirq.Simulator(dtype=np.complex128).simulate(circuit)
np.testing.assert_allclose(result.final_state_vector, identity, atol=1e-8)
14 changes: 13 additions & 1 deletion qualtran/bloqs/swap_network/cswap_approx_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from typing import Dict, Tuple, Union

import cirq
import numpy as np
import pytest
import sympy

Expand All @@ -38,6 +39,17 @@ def test_cswap_approx_decomp():
assert_valid_bloq_decomposition(csa)


def test_cswap_approx_decomposition():
csa = CSwapApprox(4)
circuit = (
csa.as_composite_bloq().to_cirq_circuit()
+ csa.adjoint().as_composite_bloq().to_cirq_circuit()
)
initial_state = cirq.testing.random_superposition(2**9, random_state=1234)
result = cirq.Simulator(dtype=np.complex128).simulate(circuit, initial_state=initial_state)
np.testing.assert_allclose(result.final_state_vector, initial_state)


@pytest.mark.parametrize('n', [5, 32])
def test_approx_cswap_t_count(n):
cswap = CSwapApprox(bitsize=n)
Expand Down
2 changes: 1 addition & 1 deletion qualtran/bloqs/swap_network/swap_with_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def build_composite_bloq(

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
num_swaps = prod(x for x in self.n_target_registers) - 1
return {(CSwapApprox(self.target_bitsize), num_swaps)}
return {(self.cswap_n, num_swaps)}

def _circuit_diagram_info_(self, args) -> cirq.CircuitDiagramInfo:
from qualtran.cirq_interop._bloq_to_cirq import _wire_symbol_to_cirq_diagram_info
Expand Down
1 change: 0 additions & 1 deletion qualtran/cirq_interop/_bloq_to_cirq.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ def _bloq_to_cirq_op(
del qvar_to_qreg[soq]

op, out_quregs = bloq.as_cirq_op(qubit_manager=qubit_manager, **in_quregs)

# 2. Update the mappings based on output soquets and `out_quregs`.
for cxn in succ_cxns:
soq = cxn.left
Expand Down
2 changes: 1 addition & 1 deletion qualtran/cirq_interop/_bloq_to_cirq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def test_bloq_as_cirq_gate_left_register():
bb.free(q)
cbloq = bb.finalize()
circuit = cbloq.to_cirq_circuit()
cirq.testing.assert_has_diagram(circuit, """_c(0): ───alloc───X───free───""")
cirq.testing.assert_has_diagram(circuit, """_c(0): ───X───""")


def test_bloq_as_cirq_gate_for_mod_exp():
Expand Down
24 changes: 22 additions & 2 deletions qualtran/cirq_interop/_interop_qubit_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,30 @@ def __init__(self, qm: Optional[cirq.QubitManager] = None):
self._managed_qubits: Set[cirq.Qid] = set()

def qalloc(self, n: int, dim: int = 2) -> List['cirq.Qid']:
return self._qm.qalloc(n, dim)
ret: List['cirq.Qid'] = []
qubits_to_free: List['cirq.Qid'] = []
while len(ret) < n:
new_alloc = self._qm.qalloc(n - len(ret), dim)
for q in new_alloc:
if q in self._managed_qubits:
qubits_to_free.append(q)
else:
ret.append(q)
self._qm.qfree(qubits_to_free)
return ret

def qborrow(self, n: int, dim: int = 2) -> List['cirq.Qid']:
return self._qm.qborrow(n, dim)
ret: List['cirq.Qid'] = []
qubits_to_free: List['cirq.Qid'] = []
while len(ret) < n:
new_alloc = self._qm.qborrow(n - len(ret), dim)
for q in new_alloc:
if q in self._managed_qubits:
qubits_to_free.append(q)
else:
ret.append(q)
self._qm.qfree(qubits_to_free)
return ret

def manage_qubits(self, qubits: Iterable[cirq.Qid]):
self._managed_qubits |= set(qubits)
Expand Down
Loading