From eddb2d9cbdf55576c6e7532e0a25b40995d889dd Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Tue, 5 Sep 2023 19:08:32 -0700 Subject: [PATCH 01/19] Add caching to `value_equality_values` decorator for auto generated methods. (#6275) * Add caching to value_equality_values decorator for auto generated methods. * Fix pylint and formatting errors * Address nits, fix bugs and make PauliSum unhashable --- cirq-core/cirq/ops/dense_pauli_string.py | 3 +++ cirq-core/cirq/ops/linear_combinations.py | 2 +- cirq-core/cirq/value/value_equality_attr.py | 14 +++++++++++--- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/cirq-core/cirq/ops/dense_pauli_string.py b/cirq-core/cirq/ops/dense_pauli_string.py index 6cf97c4eb31..9893b64f706 100644 --- a/cirq-core/cirq/ops/dense_pauli_string.py +++ b/cirq-core/cirq/ops/dense_pauli_string.py @@ -570,6 +570,9 @@ def copy( def __str__(self) -> str: return super().__str__() + ' (mutable)' + def _value_equality_values_(self): + return self.coefficient, tuple(PAULI_CHARS[p] for p in self.pauli_mask) + @classmethod def inline_gaussian_elimination(cls, rows: 'List[MutableDensePauliString]') -> None: if not rows: diff --git a/cirq-core/cirq/ops/linear_combinations.py b/cirq-core/cirq/ops/linear_combinations.py index 9f50216dab9..ed223dfe0de 100644 --- a/cirq-core/cirq/ops/linear_combinations.py +++ b/cirq-core/cirq/ops/linear_combinations.py @@ -357,7 +357,7 @@ def _pauli_string_from_unit(unit: UnitPauliStringT, coefficient: Union[int, floa return PauliString(qubit_pauli_map=dict(unit), coefficient=coefficient) -@value.value_equality(approximate=True) +@value.value_equality(approximate=True, unhashable=True) class PauliSum: """Represents operator defined by linear combination of PauliStrings. diff --git a/cirq-core/cirq/value/value_equality_attr.py b/cirq-core/cirq/value/value_equality_attr.py index 31d570430a6..f66c6549e57 100644 --- a/cirq-core/cirq/value/value_equality_attr.py +++ b/cirq-core/cirq/value/value_equality_attr.py @@ -17,7 +17,7 @@ from typing_extensions import Protocol -from cirq import protocols +from cirq import protocols, _compat class _SupportsValueEquality(Protocol): @@ -221,13 +221,21 @@ class return the existing class' type. ) else: setattr(cls, '_value_equality_values_cls_', lambda self: cls) - setattr(cls, '__hash__', None if unhashable else _value_equality_hash) + cached_values_getter = values_getter if unhashable else _compat.cached_method(values_getter) + setattr(cls, '_value_equality_values_', cached_values_getter) + setattr(cls, '__hash__', None if unhashable else _compat.cached_method(_value_equality_hash)) setattr(cls, '__eq__', _value_equality_eq) setattr(cls, '__ne__', _value_equality_ne) if approximate: if not hasattr(cls, '_value_equality_approximate_values_'): - setattr(cls, '_value_equality_approximate_values_', values_getter) + setattr(cls, '_value_equality_approximate_values_', cached_values_getter) + else: + approx_values_getter = getattr(cls, '_value_equality_approximate_values_') + cached_approx_values_getter = ( + approx_values_getter if unhashable else _compat.cached_method(approx_values_getter) + ) + setattr(cls, '_value_equality_approximate_values_', cached_approx_values_getter) setattr(cls, '_approx_eq_', _value_equality_approx_eq) return cls From e2356428df0c41e6c4cdd4418254d2edc90390a0 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Wed, 6 Sep 2023 13:39:11 -0700 Subject: [PATCH 02/19] Variable spacing QROM should depend upon structure of different data sequences and not exact elements (#6280) --- cirq-ft/cirq_ft/algos/qrom.py | 6 +----- cirq-ft/cirq_ft/algos/qrom_test.py | 4 +++- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/cirq-ft/cirq_ft/algos/qrom.py b/cirq-ft/cirq_ft/algos/qrom.py index 2bef77aab05..fdbe36792cc 100644 --- a/cirq-ft/cirq_ft/algos/qrom.py +++ b/cirq-ft/cirq_ft/algos/qrom.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Sequence, Tuple, Set +from typing import Callable, Sequence, Tuple import attr import cirq @@ -166,14 +166,10 @@ def decompose_zero_selection( context.qubit_manager.qfree(and_ancilla + [and_target]) def _break_early(self, selection_index_prefix: Tuple[int, ...], l: int, r: int): - global_unique_element: Set[int] = set() for data in self.data: unique_element = np.unique(data[selection_index_prefix][l:r]) if len(unique_element) > 1: return False - global_unique_element.add(unique_element[0]) - if len(global_unique_element) > 1: - return False return True def nth_operation( diff --git a/cirq-ft/cirq_ft/algos/qrom_test.py b/cirq-ft/cirq_ft/algos/qrom_test.py index 39b89963d8b..01025ac38c5 100644 --- a/cirq-ft/cirq_ft/algos/qrom_test.py +++ b/cirq-ft/cirq_ft/algos/qrom_test.py @@ -141,8 +141,10 @@ def test_qrom_variable_spacing(): assert cirq_ft.t_complexity(cirq_ft.QROM.build(data)).t == (8 - 2) * 4 # Works as expected when multiple data arrays are to be loaded. data = [1, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5] + # (a) Both data sequences are identical assert cirq_ft.t_complexity(cirq_ft.QROM.build(data, data)).t == (5 - 2) * 4 - assert cirq_ft.t_complexity(cirq_ft.QROM.build(data, 2 * np.array(data))).t == (16 - 2) * 4 + # (b) Both data sequences have identical structure, even though the elements are not same. + assert cirq_ft.t_complexity(cirq_ft.QROM.build(data, 2 * np.array(data))).t == (5 - 2) * 4 # Works as expected when multidimensional input data is to be loaded qrom = cirq_ft.QROM.build( np.array( From 6fae409701d2b6ec84eea85cb791ce89e2bf588d Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Wed, 6 Sep 2023 14:54:25 -0700 Subject: [PATCH 03/19] Delete `SelectionRegisters` and replace uses of Registers with `Tuple[Register, ...]` (#6278) * Delete SelectionRegisters and replace uses of Registers with Tuple[Register, ...] * Add type ignore to fix mypy error * Address Matt's comments --- cirq-ft/cirq_ft/__init__.py | 1 - cirq-ft/cirq_ft/algos/and_gate.ipynb | 6 +- cirq-ft/cirq_ft/algos/and_gate.py | 4 +- cirq-ft/cirq_ft/algos/and_gate_test.py | 7 +- .../algos/apply_gate_to_lth_target.ipynb | 4 +- .../cirq_ft/algos/apply_gate_to_lth_target.py | 48 ++-- .../algos/apply_gate_to_lth_target_test.py | 20 +- cirq-ft/cirq_ft/algos/arithmetic_gates.py | 4 +- cirq-ft/cirq_ft/algos/generic_select.py | 21 +- cirq-ft/cirq_ft/algos/generic_select_test.py | 3 +- cirq-ft/cirq_ft/algos/hubbard_model.py | 89 +++---- cirq-ft/cirq_ft/algos/hubbard_model_test.py | 3 +- .../mean_estimation/complex_phase_oracle.py | 11 +- .../complex_phase_oracle_test.py | 15 +- .../mean_estimation_operator.py | 10 +- .../mean_estimation_operator_test.py | 44 ++-- .../algos/multi_control_multi_target_pauli.py | 2 +- .../phase_estimation_of_quantum_walk.ipynb | 3 +- .../algos/prepare_uniform_superposition.py | 4 +- .../prepare_uniform_superposition_test.py | 3 +- .../algos/programmable_rotation_gate_array.py | 22 +- .../programmable_rotation_gate_array_test.py | 8 +- cirq-ft/cirq_ft/algos/qrom.py | 40 ++- cirq-ft/cirq_ft/algos/qrom_test.py | 8 +- .../algos/qubitization_walk_operator.py | 14 +- .../algos/qubitization_walk_operator_test.py | 9 +- .../cirq_ft/algos/reflection_using_prepare.py | 15 +- .../algos/reflection_using_prepare_test.py | 9 +- cirq-ft/cirq_ft/algos/select_and_prepare.py | 13 +- cirq-ft/cirq_ft/algos/select_swap_qrom.py | 31 ++- .../cirq_ft/algos/select_swap_qrom_test.py | 7 +- .../algos/selected_majorana_fermion.py | 51 ++-- .../algos/selected_majorana_fermion_test.py | 27 +- cirq-ft/cirq_ft/algos/state_preparation.py | 28 ++- cirq-ft/cirq_ft/algos/swap_network.py | 12 +- cirq-ft/cirq_ft/algos/swap_network_test.py | 3 +- cirq-ft/cirq_ft/algos/unary_iteration.ipynb | 14 +- cirq-ft/cirq_ft/algos/unary_iteration_gate.py | 26 +- .../algos/unary_iteration_gate_test.py | 53 ++-- cirq-ft/cirq_ft/infra/__init__.py | 6 +- .../cirq_ft/infra/gate_with_registers.ipynb | 4 +- cirq-ft/cirq_ft/infra/gate_with_registers.py | 233 ++++++++---------- .../cirq_ft/infra/gate_with_registers_test.py | 38 +-- cirq-ft/cirq_ft/infra/jupyter_tools.py | 4 +- cirq-ft/cirq_ft/infra/t_complexity.ipynb | 12 +- .../infra/t_complexity_protocol_test.py | 5 +- cirq-ft/cirq_ft/infra/testing.py | 6 +- cirq-ft/cirq_ft/infra/type_convertors.py | 20 -- cirq-ft/cirq_ft/infra/type_convertors_test.py | 21 -- 49 files changed, 488 insertions(+), 553 deletions(-) delete mode 100644 cirq-ft/cirq_ft/infra/type_convertors.py delete mode 100644 cirq-ft/cirq_ft/infra/type_convertors_test.py diff --git a/cirq-ft/cirq_ft/__init__.py b/cirq-ft/cirq_ft/__init__.py index 00053e949b5..47bf47cf660 100644 --- a/cirq-ft/cirq_ft/__init__.py +++ b/cirq-ft/cirq_ft/__init__.py @@ -49,7 +49,6 @@ Register, Registers, SelectionRegister, - SelectionRegisters, TComplexity, map_clean_and_borrowable_qubits, t_complexity, diff --git a/cirq-ft/cirq_ft/algos/and_gate.ipynb b/cirq-ft/cirq_ft/algos/and_gate.ipynb index ca1562840ad..498081f0cb4 100644 --- a/cirq-ft/cirq_ft/algos/and_gate.ipynb +++ b/cirq-ft/cirq_ft/algos/and_gate.ipynb @@ -63,11 +63,11 @@ "source": [ "import cirq\n", "from cirq.contrib.svg import SVGCircuit\n", - "from cirq_ft import And\n", + "from cirq_ft import And, infra\n", "\n", "gate = And()\n", "r = gate.registers\n", - "quregs = r.get_named_qubits()\n", + "quregs = infra.get_named_qubits(r)\n", "operation = gate.on_registers(**quregs)\n", "circuit = cirq.Circuit(operation)\n", "SVGCircuit(circuit)" @@ -223,4 +223,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/cirq-ft/cirq_ft/algos/and_gate.py b/cirq-ft/cirq_ft/algos/and_gate.py index 973528386dc..f308926d632 100644 --- a/cirq-ft/cirq_ft/algos/and_gate.py +++ b/cirq-ft/cirq_ft/algos/and_gate.py @@ -49,7 +49,9 @@ class And(infra.GateWithRegisters): ValueError: If number of control values (i.e. `len(self.cv)`) is less than 2. """ - cv: Tuple[int, ...] = attr.field(default=(1, 1), converter=infra.to_tuple) + cv: Tuple[int, ...] = attr.field( + default=(1, 1), converter=lambda v: (v,) if isinstance(v, int) else tuple(v) + ) adjoint: bool = False @cv.validator diff --git a/cirq-ft/cirq_ft/algos/and_gate_test.py b/cirq-ft/cirq_ft/algos/and_gate_test.py index f41b6a271c1..70de51a205b 100644 --- a/cirq-ft/cirq_ft/algos/and_gate_test.py +++ b/cirq-ft/cirq_ft/algos/and_gate_test.py @@ -20,6 +20,7 @@ import cirq_ft import numpy as np import pytest +from cirq_ft import infra from cirq_ft.infra.jupyter_tools import execute_notebook random.seed(12345) @@ -46,12 +47,12 @@ def test_multi_controlled_and_gate(cv: List[int]): gate = cirq_ft.And(cv) r = gate.registers assert r['ancilla'].total_bits() == r['control'].total_bits() - 2 - quregs = r.get_named_qubits() + quregs = infra.get_named_qubits(r) and_op = gate.on_registers(**quregs) circuit = cirq.Circuit(and_op) input_controls = [cv] + [random_cv(len(cv)) for _ in range(10)] - qubit_order = gate.registers.merge_qubits(**quregs) + qubit_order = infra.merge_qubits(gate.registers, **quregs) for input_control in input_controls: initial_state = input_control + [0] * (r['ancilla'].total_bits() + 1) @@ -77,7 +78,7 @@ def test_multi_controlled_and_gate(cv: List[int]): def test_and_gate_diagram(): gate = cirq_ft.And((1, 0, 1, 0, 1, 0)) - qubit_regs = gate.registers.get_named_qubits() + qubit_regs = infra.get_named_qubits(gate.registers) op = gate.on_registers(**qubit_regs) # Qubit order should be alternating (control, ancilla) pairs. c_and_a = sum(zip(qubit_regs["control"][1:], qubit_regs["ancilla"]), ()) + ( diff --git a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb index 90e4ac086fe..3e34704c619 100644 --- a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb +++ b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb @@ -69,7 +69,7 @@ "`selection`-th qubit of `target` all controlled by the `control` register.\n", "\n", "#### Parameters\n", - " - `selection_regs`: Indexing `select` registers of type `SelectionRegisters`. It also contains information about the iteration length of each selection register.\n", + " - `selection_regs`: Indexing `select` registers of type Tuple[`SelectionRegister`, ...]. It also contains information about the iteration length of each selection register.\n", " - `nth_gate`: A function mapping the composite selection index to a single-qubit gate.\n", " - `control_regs`: Control registers for constructing a controlled version of the gate.\n" ] @@ -89,7 +89,7 @@ " return cirq.I\n", "\n", "apply_z_to_odd = cirq_ft.ApplyGateToLthQubit(\n", - " cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 4)]),\n", + " cirq_ft.SelectionRegister('selection', 3, 4),\n", " nth_gate=_z_to_odd,\n", " control_regs=cirq_ft.Registers.build(control=2),\n", ")\n", diff --git a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py index e796b1b05f0..e3bb08be143 100644 --- a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py +++ b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py @@ -13,10 +13,11 @@ # limitations under the License. import itertools -from typing import Callable, Sequence +from typing import Callable, Sequence, Tuple import attr import cirq +import numpy as np from cirq._compat import cached_property from cirq_ft import infra from cirq_ft.algos import unary_iteration_gate @@ -36,8 +37,8 @@ class ApplyGateToLthQubit(unary_iteration_gate.UnaryIterationGate): `selection`-th qubit of `target` all controlled by the `control` register. Args: - selection_regs: Indexing `select` registers of type `SelectionRegisters`. It also contains - information about the iteration length of each selection register. + selection_regs: Indexing `select` registers of type Tuple[`SelectionRegisters`, ...]. + It also contains information about the iteration length of each selection register. nth_gate: A function mapping the composite selection index to a single-qubit gate. control_regs: Control registers for constructing a controlled version of the gate. @@ -46,43 +47,45 @@ class ApplyGateToLthQubit(unary_iteration_gate.UnaryIterationGate): (https://arxiv.org/abs/1805.03662). Babbush et. al. (2018). Section III.A. and Figure 7. """ - selection_regs: infra.SelectionRegisters + selection_regs: Tuple[infra.SelectionRegister, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, infra.SelectionRegister) else tuple(v) + ) nth_gate: Callable[..., cirq.Gate] - control_regs: infra.Registers = infra.Registers.build(control=1) + control_regs: Tuple[infra.Register, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, infra.Register) else tuple(v), + default=(infra.Register('control', 1),), + ) @classmethod def make_on( cls, *, nth_gate: Callable[..., cirq.Gate], **quregs: Sequence[cirq.Qid] ) -> cirq.Operation: """Helper constructor to automatically deduce bitsize attributes.""" - return cls( - infra.SelectionRegisters( - [ - infra.SelectionRegister( - 'selection', len(quregs['selection']), len(quregs['target']) - ) - ] - ), + return ApplyGateToLthQubit( + infra.SelectionRegister('selection', len(quregs['selection']), len(quregs['target'])), nth_gate=nth_gate, - control_regs=infra.Registers.build(control=len(quregs['control'])), + control_regs=infra.Register('control', len(quregs['control'])), ).on_registers(**quregs) @cached_property - def control_registers(self) -> infra.Registers: + def control_registers(self) -> Tuple[infra.Register, ...]: return self.control_regs @cached_property - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: return self.selection_regs @cached_property - def target_registers(self) -> infra.Registers: - return infra.Registers.build(target=self.selection_registers.total_iteration_size) + def target_registers(self) -> Tuple[infra.Register, ...]: + total_iteration_size = np.product( + tuple(reg.iteration_length for reg in self.selection_registers) + ) + return (infra.Register('target', int(total_iteration_size)),) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: - wire_symbols = ["@"] * self.control_registers.total_bits() - wire_symbols += ["In"] * self.selection_registers.total_bits() - for it in itertools.product(*[range(x) for x in self.selection_regs.iteration_lengths]): + wire_symbols = ["@"] * infra.total_bits(self.control_registers) + wire_symbols += ["In"] * infra.total_bits(self.selection_registers) + for it in itertools.product(*[range(reg.iteration_length) for reg in self.selection_regs]): wire_symbols += [str(self.nth_gate(*it))] return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) @@ -93,6 +96,7 @@ def nth_operation( # type: ignore[override] target: Sequence[cirq.Qid], **selection_indices: int, ) -> cirq.OP_TREE: + selection_shape = tuple(reg.iteration_length for reg in self.selection_regs) selection_idx = tuple(selection_indices[reg.name] for reg in self.selection_regs) - target_idx = self.selection_registers.to_flat_idx(*selection_idx) + target_idx = int(np.ravel_multi_index(selection_idx, selection_shape)) return self.nth_gate(*selection_idx).on(target[target_idx]).controlled_by(control) diff --git a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py index da285792d36..2c2e29e7c0c 100644 --- a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py +++ b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py @@ -15,6 +15,7 @@ import cirq import cirq_ft import pytest +from cirq_ft import infra from cirq_ft.infra.bit_tools import iter_bits from cirq_ft.infra.jupyter_tools import execute_notebook @@ -23,16 +24,13 @@ def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize): greedy_mm = cirq_ft.GreedyQubitManager(prefix="_a", maximize_reuse=True) gate = cirq_ft.ApplyGateToLthQubit( - cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)] - ), - lambda _: cirq.X, + cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize), lambda _: cirq.X ) g = cirq_ft.testing.GateHelper(gate, context=cirq.DecompositionContext(greedy_mm)) # Upper bounded because not all ancillas may be used as part of unary iteration. assert ( len(g.all_qubits) - <= target_bitsize + 2 * (selection_bitsize + gate.control_registers.total_bits()) - 1 + <= target_bitsize + 2 * (selection_bitsize + infra.total_bits(gate.control_registers)) - 1 ) for n in range(target_bitsize): @@ -54,12 +52,12 @@ def test_apply_gate_to_lth_qubit(selection_bitsize, target_bitsize): def test_apply_gate_to_lth_qubit_diagram(): # Apply Z gate to all odd targets and Identity to even targets. gate = cirq_ft.ApplyGateToLthQubit( - cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 5)]), + cirq_ft.SelectionRegister('selection', 3, 5), lambda n: cirq.Z if n & 1 else cirq.I, control_regs=cirq_ft.Registers.build(control=2), ) - circuit = cirq.Circuit(gate.on_registers(**gate.registers.get_named_qubits())) - qubits = list(q for v in gate.registers.get_named_qubits().values() for q in v) + circuit = cirq.Circuit(gate.on_registers(**infra.get_named_qubits(gate.registers))) + qubits = list(q for v in infra.get_named_qubits(gate.registers).values() for q in v) cirq.testing.assert_has_diagram( circuit, """ @@ -89,13 +87,13 @@ def test_apply_gate_to_lth_qubit_diagram(): def test_apply_gate_to_lth_qubit_make_on(): gate = cirq_ft.ApplyGateToLthQubit( - cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('selection', 3, 5)]), + cirq_ft.SelectionRegister('selection', 3, 5), lambda n: cirq.Z if n & 1 else cirq.I, control_regs=cirq_ft.Registers.build(control=2), ) - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) op2 = cirq_ft.ApplyGateToLthQubit.make_on( - nth_gate=lambda n: cirq.Z if n & 1 else cirq.I, **gate.registers.get_named_qubits() + nth_gate=lambda n: cirq.Z if n & 1 else cirq.I, **infra.get_named_qubits(gate.registers) ) # Note: ApplyGateToLthQubit doesn't support value equality. assert op.qubits == op2.qubits diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates.py b/cirq-ft/cirq_ft/algos/arithmetic_gates.py index b75383015e3..6054c90709f 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates.py @@ -611,7 +611,9 @@ class AddMod(cirq.ArithmeticGate): bitsize: int mod: int = attr.field() add_val: int = 1 - cv: Tuple[int, ...] = attr.field(converter=infra.to_tuple, default=()) + cv: Tuple[int, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=() + ) @mod.validator def _validate_mod(self, attribute, value): diff --git a/cirq-ft/cirq_ft/algos/generic_select.py b/cirq-ft/cirq_ft/algos/generic_select.py index 68d62cf98f6..8822beb32f6 100644 --- a/cirq-ft/cirq_ft/algos/generic_select.py +++ b/cirq-ft/cirq_ft/algos/generic_select.py @@ -68,23 +68,20 @@ def __attrs_post_init__(self): ) @cached_property - def control_registers(self) -> infra.Registers: - registers = [] if self.control_val is None else [infra.Register('control', 1)] - return infra.Registers(registers) + def control_registers(self) -> Tuple[infra.Register, ...]: + return () if self.control_val is None else (infra.Register('control', 1),) @cached_property - def selection_registers(self) -> infra.SelectionRegisters: - return infra.SelectionRegisters( - [ - infra.SelectionRegister( - 'selection', self.selection_bitsize, len(self.select_unitaries) - ) - ] + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: + return ( + infra.SelectionRegister( + 'selection', self.selection_bitsize, len(self.select_unitaries) + ), ) @cached_property - def target_registers(self) -> infra.Registers: - return infra.Registers.build(target=self.target_bitsize) + def target_registers(self) -> Tuple[infra.Register, ...]: + return (infra.Register('target', self.target_bitsize),) def decompose_from_registers( self, context, **quregs: NDArray[cirq.Qid] # type:ignore[type-var] diff --git a/cirq-ft/cirq_ft/algos/generic_select_test.py b/cirq-ft/cirq_ft/algos/generic_select_test.py index c074f8d3197..255e9ba6b79 100644 --- a/cirq-ft/cirq_ft/algos/generic_select_test.py +++ b/cirq-ft/cirq_ft/algos/generic_select_test.py @@ -17,6 +17,7 @@ import cirq_ft import numpy as np import pytest +from cirq_ft import infra from cirq_ft.infra.bit_tools import iter_bits from cirq_ft.infra.jupyter_tools import execute_notebook @@ -255,7 +256,7 @@ def test_generic_select_consistent_protocols_and_controlled(): # Build GenericSelect gate. gate = cirq_ft.GenericSelect(select_bitsize, num_sites, dps_hamiltonian) - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) cirq.testing.assert_equivalent_repr(gate, setup_code='import cirq\nimport cirq_ft') # Build controlled gate diff --git a/cirq-ft/cirq_ft/algos/hubbard_model.py b/cirq-ft/cirq_ft/algos/hubbard_model.py index 520d305062a..dad2a443c9a 100644 --- a/cirq-ft/cirq_ft/algos/hubbard_model.py +++ b/cirq-ft/cirq_ft/algos/hubbard_model.py @@ -118,28 +118,25 @@ def __attrs_post_init__(self): raise NotImplementedError("Currently only supports the case where x_dim=y_dim.") @cached_property - def control_registers(self) -> infra.Registers: - registers = [] if self.control_val is None else [infra.Register('control', 1)] - return infra.Registers(registers) + def control_registers(self) -> Tuple[infra.Register, ...]: + return () if self.control_val is None else (infra.Register('control', 1),) @cached_property - def selection_registers(self) -> infra.SelectionRegisters: - return infra.SelectionRegisters( - [ - infra.SelectionRegister('U', 1, 2), - infra.SelectionRegister('V', 1, 2), - infra.SelectionRegister('p_x', (self.x_dim - 1).bit_length(), self.x_dim), - infra.SelectionRegister('p_y', (self.y_dim - 1).bit_length(), self.y_dim), - infra.SelectionRegister('alpha', 1, 2), - infra.SelectionRegister('q_x', (self.x_dim - 1).bit_length(), self.x_dim), - infra.SelectionRegister('q_y', (self.y_dim - 1).bit_length(), self.y_dim), - infra.SelectionRegister('beta', 1, 2), - ] + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: + return ( + infra.SelectionRegister('U', 1, 2), + infra.SelectionRegister('V', 1, 2), + infra.SelectionRegister('p_x', (self.x_dim - 1).bit_length(), self.x_dim), + infra.SelectionRegister('p_y', (self.y_dim - 1).bit_length(), self.y_dim), + infra.SelectionRegister('alpha', 1, 2), + infra.SelectionRegister('q_x', (self.x_dim - 1).bit_length(), self.x_dim), + infra.SelectionRegister('q_y', (self.y_dim - 1).bit_length(), self.y_dim), + infra.SelectionRegister('beta', 1, 2), ) @cached_property - def target_registers(self) -> infra.Registers: - return infra.Registers.build(target=self.x_dim * self.y_dim * 2) + def target_registers(self) -> Tuple[infra.Register, ...]: + return (infra.Register('target', self.x_dim * self.y_dim * 2),) @cached_property def registers(self) -> infra.Registers: @@ -158,12 +155,10 @@ def decompose_from_registers( control, target = quregs.get('control', ()), quregs['target'] yield selected_majorana_fermion.SelectedMajoranaFermionGate( - selection_regs=infra.SelectionRegisters( - [ - infra.SelectionRegister('alpha', 1, 2), - infra.SelectionRegister('p_y', self.registers['p_y'].total_bits(), self.y_dim), - infra.SelectionRegister('p_x', self.registers['p_x'].total_bits(), self.x_dim), - ] + selection_regs=( + infra.SelectionRegister('alpha', 1, 2), + infra.SelectionRegister('p_y', self.registers['p_y'].total_bits(), self.y_dim), + infra.SelectionRegister('p_x', self.registers['p_x'].total_bits(), self.x_dim), ), control_regs=self.control_registers, target_gate=cirq.Y, @@ -173,12 +168,10 @@ def decompose_from_registers( yield swap_network.MultiTargetCSwap.make_on(control=V, target_x=p_y, target_y=q_y) yield swap_network.MultiTargetCSwap.make_on(control=V, target_x=alpha, target_y=beta) - q_selection_regs = infra.SelectionRegisters( - [ - infra.SelectionRegister('beta', 1, 2), - infra.SelectionRegister('q_y', self.registers['q_y'].total_bits(), self.y_dim), - infra.SelectionRegister('q_x', self.registers['q_x'].total_bits(), self.x_dim), - ] + q_selection_regs = ( + infra.SelectionRegister('beta', 1, 2), + infra.SelectionRegister('q_y', self.registers['q_y'].total_bits(), self.y_dim), + infra.SelectionRegister('q_x', self.registers['q_x'].total_bits(), self.x_dim), ) yield selected_majorana_fermion.SelectedMajoranaFermionGate( selection_regs=q_selection_regs, control_regs=self.control_registers, target_gate=cirq.X @@ -194,20 +187,18 @@ def decompose_from_registers( yield cirq.Z(*U).controlled_by(*control) # Fix errant -1 from multiple pauli applications target_qubits_for_apply_to_lth_gate = [ - target[q_selection_regs.to_flat_idx(1, qy, qx)] + target[np.ravel_multi_index((1, qy, qx), (2, self.y_dim, self.x_dim))] for qx in range(self.x_dim) for qy in range(self.y_dim) ] yield apply_gate_to_lth_target.ApplyGateToLthQubit( - selection_regs=infra.SelectionRegisters( - [ - infra.SelectionRegister('q_y', self.registers['q_y'].total_bits(), self.y_dim), - infra.SelectionRegister('q_x', self.registers['q_x'].total_bits(), self.x_dim), - ] + selection_regs=( + infra.SelectionRegister('q_y', self.registers['q_y'].total_bits(), self.y_dim), + infra.SelectionRegister('q_x', self.registers['q_x'].total_bits(), self.x_dim), ), nth_gate=lambda *_: cirq.Z, - control_regs=infra.Registers.build(control=1 + self.control_registers.total_bits()), + control_regs=infra.Register('control', 1 + infra.total_bits(self.control_registers)), ).on_registers( q_x=q_x, q_y=q_y, control=[*V, *control], target=target_qubits_for_apply_to_lth_gate ) @@ -291,23 +282,21 @@ def __attrs_post_init__(self): raise NotImplementedError("Currently only supports the case where x_dim=y_dim.") @cached_property - def selection_registers(self) -> infra.SelectionRegisters: - return infra.SelectionRegisters( - [ - infra.SelectionRegister('U', 1, 2), - infra.SelectionRegister('V', 1, 2), - infra.SelectionRegister('p_x', (self.x_dim - 1).bit_length(), self.x_dim), - infra.SelectionRegister('p_y', (self.y_dim - 1).bit_length(), self.y_dim), - infra.SelectionRegister('alpha', 1, 2), - infra.SelectionRegister('q_x', (self.x_dim - 1).bit_length(), self.x_dim), - infra.SelectionRegister('q_y', (self.y_dim - 1).bit_length(), self.y_dim), - infra.SelectionRegister('beta', 1, 2), - ] + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: + return ( + infra.SelectionRegister('U', 1, 2), + infra.SelectionRegister('V', 1, 2), + infra.SelectionRegister('p_x', (self.x_dim - 1).bit_length(), self.x_dim), + infra.SelectionRegister('p_y', (self.y_dim - 1).bit_length(), self.y_dim), + infra.SelectionRegister('alpha', 1, 2), + infra.SelectionRegister('q_x', (self.x_dim - 1).bit_length(), self.x_dim), + infra.SelectionRegister('q_y', (self.y_dim - 1).bit_length(), self.y_dim), + infra.SelectionRegister('beta', 1, 2), ) @cached_property - def junk_registers(self) -> infra.Registers: - return infra.Registers.build(temp=2) + def junk_registers(self) -> Tuple[infra.Register, ...]: + return (infra.Register('temp', 2),) @cached_property def registers(self) -> infra.Registers: diff --git a/cirq-ft/cirq_ft/algos/hubbard_model_test.py b/cirq-ft/cirq_ft/algos/hubbard_model_test.py index 43c66f9bf0f..b13f9e6dfd6 100644 --- a/cirq-ft/cirq_ft/algos/hubbard_model_test.py +++ b/cirq-ft/cirq_ft/algos/hubbard_model_test.py @@ -15,6 +15,7 @@ import cirq import cirq_ft import pytest +from cirq_ft import infra from cirq_ft.infra.jupyter_tools import execute_notebook @@ -48,7 +49,7 @@ def test_hubbard_model_consistent_protocols(): cirq.testing.assert_equivalent_repr(prepare_gate, setup_code='import cirq_ft') # Build controlled SELECT gate - select_op = select_gate.on_registers(**select_gate.registers.get_named_qubits()) + select_op = select_gate.on_registers(**infra.get_named_qubits(select_gate.registers)) equals_tester = cirq.testing.EqualsTester() equals_tester.add_equality_group( select_gate.controlled(), diff --git a/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle.py b/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle.py index 74f992b0f5d..695a51ba854 100644 --- a/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle.py +++ b/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple from numpy.typing import NDArray import attr @@ -37,11 +38,11 @@ class ComplexPhaseOracle(infra.GateWithRegisters): arctan_bitsize: int = 32 @cached_property - def control_registers(self) -> infra.Registers: + def control_registers(self) -> Tuple[infra.Register, ...]: return self.encoder.control_registers @cached_property - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: return self.encoder.selection_registers @cached_property @@ -58,7 +59,7 @@ def decompose_from_registers( target_reg = { reg.name: qm.qalloc(reg.total_bits()) for reg in self.encoder.target_registers } - target_qubits = self.encoder.target_registers.merge_qubits(**target_reg) + target_qubits = infra.merge_qubits(self.encoder.target_registers, **target_reg) encoder_op = self.encoder.on_registers(**quregs, **target_reg) arctan_sign, arctan_target = qm.qalloc(1), qm.qalloc(self.arctan_bitsize) @@ -78,6 +79,6 @@ def decompose_from_registers( qm.qfree([*arctan_sign, *arctan_target, *target_qubits]) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: - wire_symbols = ['@'] * self.control_registers.total_bits() - wire_symbols += ['ROTy'] * self.selection_registers.total_bits() + wire_symbols = ['@'] * infra.total_bits(self.control_registers) + wire_symbols += ['ROTy'] * infra.total_bits(self.selection_registers) return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) diff --git a/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle_test.py b/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle_test.py index 7e6b61527c2..a7926bf3847 100644 --- a/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle_test.py +++ b/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle_test.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Optional +from typing import Optional, Tuple import cirq import cirq_ft @@ -32,17 +32,16 @@ class DummySelect(cirq_ft.SelectOracle): control_val: Optional[int] = None @cached_property - def control_registers(self) -> cirq_ft.Registers: - registers = [] if self.control_val is None else [cirq_ft.Register('control', 1)] - return cirq_ft.Registers(registers) + def control_registers(self) -> Tuple[cirq_ft.Register, ...]: + return () if self.control_val is None else (cirq_ft.Register('control', 1),) @cached_property - def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters.build(selection=self.bitsize) + def selection_registers(self) -> Tuple[cirq_ft.SelectionRegister, ...]: + return (cirq_ft.SelectionRegister('selection', self.bitsize),) @cached_property - def target_registers(self) -> cirq_ft.Registers: - return cirq_ft.Registers.build(target=self.bitsize) + def target_registers(self) -> Tuple[cirq_ft.Register, ...]: + return (cirq_ft.Register('target', self.bitsize),) def decompose_from_registers(self, context, selection, target): yield [cirq.CNOT(s, t) for s, t in zip(selection, target)] diff --git a/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator.py b/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator.py index f3fff37ecd7..40de332dad1 100644 --- a/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator.py +++ b/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator.py @@ -80,7 +80,9 @@ class MeanEstimationOperator(infra.GateWithRegisters): """ code: CodeForRandomVariable - cv: Tuple[int, ...] = attr.field(converter=infra.to_tuple, default=()) + cv: Tuple[int, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=() + ) power: int = 1 arctan_bitsize: int = 32 @@ -99,11 +101,11 @@ def select(self) -> complex_phase_oracle.ComplexPhaseOracle: return complex_phase_oracle.ComplexPhaseOracle(self.code.encoder, self.arctan_bitsize) @cached_property - def control_registers(self) -> infra.Registers: + def control_registers(self) -> Tuple[infra.Register, ...]: return self.code.encoder.control_registers @cached_property - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: return self.code.encoder.selection_registers @cached_property @@ -130,7 +132,7 @@ def decompose_from_registers( def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: wire_symbols = [] if self.cv == () else [["@(0)", "@"][self.cv[0]]] wire_symbols += ['U_ko'] * ( - self.registers.total_bits() - self.control_registers.total_bits() + infra.total_bits(self.registers) - infra.total_bits(self.control_registers) ) if self.power != 1: wire_symbols[-1] = f'U_ko^{self.power}' diff --git a/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator_test.py b/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator_test.py index f9f7c359165..ef9596dd861 100644 --- a/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator_test.py +++ b/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator_test.py @@ -19,6 +19,7 @@ import numpy as np import pytest from attr import frozen +from cirq_ft import infra from cirq._compat import cached_property from cirq_ft.algos.mean_estimation import CodeForRandomVariable, MeanEstimationOperator from cirq_ft.infra import bit_tools @@ -32,8 +33,8 @@ class BernoulliSynthesizer(cirq_ft.PrepareOracle): nqubits: int @cached_property - def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('q', self.nqubits, 2)]) + def selection_registers(self) -> Tuple[cirq_ft.SelectionRegister, ...]: + return (cirq_ft.SelectionRegister('q', self.nqubits, 2),) def decompose_from_registers( # type:ignore[override] self, context, q: Sequence[cirq.Qid] @@ -54,19 +55,16 @@ class BernoulliEncoder(cirq_ft.SelectOracle): control_val: Optional[int] = None @cached_property - def control_registers(self) -> cirq_ft.Registers: - registers = [] if self.control_val is None else [cirq_ft.Register('control', 1)] - return cirq_ft.Registers(registers) + def control_registers(self) -> Tuple[cirq_ft.Register, ...]: + return () if self.control_val is None else (cirq_ft.Register('control', 1),) @cached_property - def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('q', self.selection_bitsize, 2)] - ) + def selection_registers(self) -> Tuple[cirq_ft.SelectionRegister, ...]: + return (cirq_ft.SelectionRegister('q', self.selection_bitsize, 2),) @cached_property - def target_registers(self) -> cirq_ft.Registers: - return cirq_ft.Registers.build(t=self.target_bitsize) + def target_registers(self) -> Tuple[cirq_ft.Register, ...]: + return (cirq_ft.Register('t', self.target_bitsize),) def decompose_from_registers( # type:ignore[override] self, context, q: Sequence[cirq.Qid], t: Sequence[cirq.Qid] @@ -119,7 +117,7 @@ def satisfies_theorem_321( assert cirq.is_unitary(u) # Compute the final state vector obtained using the synthesizer `Prep |0>` - prep_op = synthesizer.on_registers(**synthesizer.registers.get_named_qubits()) + prep_op = synthesizer.on_registers(**infra.get_named_qubits(synthesizer.registers)) prep_state = cirq.Circuit(prep_op).final_state_vector() expected_hav = abs(mu) * np.sqrt(1 / (1 + s**2)) @@ -174,8 +172,8 @@ class GroverSynthesizer(cirq_ft.PrepareOracle): n: int @cached_property - def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters.build(selection=self.n) + def selection_registers(self) -> Tuple[cirq_ft.SelectionRegister, ...]: + return (cirq_ft.SelectionRegister('selection', self.n),) def decompose_from_registers( # type:ignore[override] self, *, context, selection: Sequence[cirq.Qid] @@ -197,24 +195,24 @@ class GroverEncoder(cirq_ft.SelectOracle): marked_val: int @cached_property - def control_registers(self) -> cirq_ft.Registers: - return cirq_ft.Registers([]) + def control_registers(self) -> Tuple[cirq_ft.Register, ...]: + return () @cached_property - def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters.build(selection=self.n) + def selection_registers(self) -> Tuple[cirq_ft.SelectionRegister, ...]: + return (cirq_ft.SelectionRegister('selection', self.n),) @cached_property - def target_registers(self) -> cirq_ft.Registers: - return cirq_ft.Registers.build(target=self.marked_val.bit_length()) + def target_registers(self) -> Tuple[cirq_ft.Register, ...]: + return (cirq_ft.Register('target', self.marked_val.bit_length()),) def decompose_from_registers( # type:ignore[override] self, context, *, selection: Sequence[cirq.Qid], target: Sequence[cirq.Qid] ) -> cirq.OP_TREE: selection_cv = [ - *bit_tools.iter_bits(self.marked_item, self.selection_registers.total_bits()) + *bit_tools.iter_bits(self.marked_item, infra.total_bits(self.selection_registers)) ] - yval_bin = [*bit_tools.iter_bits(self.marked_val, self.target_registers.total_bits())] + yval_bin = [*bit_tools.iter_bits(self.marked_val, infra.total_bits(self.target_registers))] for b, q in zip(yval_bin, target): if b: @@ -254,7 +252,7 @@ def test_mean_estimation_operator_consistent_protocols(): encoder = BernoulliEncoder(p, (0, y_1), selection_bitsize, target_bitsize) code = CodeForRandomVariable(synthesizer=synthesizer, encoder=encoder) mean_gate = MeanEstimationOperator(code, arctan_bitsize=arctan_bitsize) - op = mean_gate.on_registers(**mean_gate.registers.get_named_qubits()) + op = mean_gate.on_registers(**infra.get_named_qubits(mean_gate.registers)) # Test controlled gate. equals_tester = cirq.testing.EqualsTester() diff --git a/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py b/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py index 6ab7e65e51f..bb96215b729 100644 --- a/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py +++ b/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py @@ -73,7 +73,7 @@ class MultiControlPauli(infra.GateWithRegisters): (https://algassert.com/circuits/2015/06/05/Constructing-Large-Controlled-Nots.html) """ - cvs: Tuple[int, ...] = attr.field(converter=infra.to_tuple) + cvs: Tuple[int, ...] = attr.field(converter=lambda v: (v,) if isinstance(v, int) else tuple(v)) target_gate: cirq.Pauli = cirq.X @cached_property diff --git a/cirq-ft/cirq_ft/algos/phase_estimation_of_quantum_walk.ipynb b/cirq-ft/cirq_ft/algos/phase_estimation_of_quantum_walk.ipynb index 206b4466b32..6121248b9fa 100644 --- a/cirq-ft/cirq_ft/algos/phase_estimation_of_quantum_walk.ipynb +++ b/cirq-ft/cirq_ft/algos/phase_estimation_of_quantum_walk.ipynb @@ -39,6 +39,7 @@ "import numpy as np\n", "\n", "import cirq_ft\n", + "from cirq_ft import infra\n", "\n", "from cirq_ft.algos.qubitization_walk_operator_test import get_walk_operator_for_1d_Ising_model\n", "from cirq_ft.algos.hubbard_model import get_walk_operator_for_hubbard_model" @@ -87,7 +88,7 @@ " Fig. 2\n", " \"\"\"\n", " reflect = walk.reflect\n", - " walk_regs = walk.registers.get_named_qubits()\n", + " walk_regs = infra.get_named_qubits(walk.registers)\n", " reflect_regs = {k:v for k, v in walk_regs.items() if k in reflect.registers}\n", " \n", " reflect_controlled = reflect.controlled(control_values=[0])\n", diff --git a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py index e75f735bfe9..374415e90bc 100644 --- a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py +++ b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py @@ -46,7 +46,9 @@ class PrepareUniformSuperposition(infra.GateWithRegisters): """ n: int - cv: Tuple[int, ...] = attr.field(converter=infra.to_tuple, default=()) + cv: Tuple[int, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=() + ) @cached_property def registers(self) -> infra.Registers: diff --git a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition_test.py b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition_test.py index b89ac0bb698..f58cc671e63 100644 --- a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition_test.py +++ b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition_test.py @@ -14,6 +14,7 @@ import cirq import cirq_ft +from cirq_ft import infra import numpy as np import pytest @@ -51,7 +52,7 @@ def test_prepare_uniform_superposition_t_complexity(n: int): result = cirq_ft.t_complexity(gate) # TODO(#233): Controlled-H is currently counted as a separate rotation, but it can be # implemented using 2 T-gates. - assert result.rotations <= 2 + 2 * gate.registers.total_bits() + assert result.rotations <= 2 + 2 * infra.total_bits(gate.registers) assert result.t <= 12 * (n - 1).bit_length() diff --git a/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array.py b/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array.py index 47e43857247..158ec1a112d 100644 --- a/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array.py +++ b/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array.py @@ -102,22 +102,20 @@ def interleaved_unitary( pass @cached_property - def selection_registers(self) -> infra.SelectionRegisters: - return infra.SelectionRegisters( - [infra.SelectionRegister('selection', self._selection_bitsize, len(self.angles[0]))] - ) + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: + return (infra.SelectionRegister('selection', self._selection_bitsize, len(self.angles[0])),) @cached_property - def kappa_load_target(self) -> infra.Registers: - return infra.Registers.build(kappa_load_target=self.kappa) + def kappa_load_target(self) -> Tuple[infra.Register, ...]: + return (infra.Register('kappa_load_target', self.kappa),) @cached_property - def rotations_target(self) -> infra.Registers: - return infra.Registers.build(rotations_target=self._target_bitsize) + def rotations_target(self) -> Tuple[infra.Register, ...]: + return (infra.Register('rotations_target', self._target_bitsize),) @property @abc.abstractmethod - def interleaved_unitary_target(self) -> infra.Registers: + def interleaved_unitary_target(self) -> Tuple[infra.Register, ...]: pass @cached_property @@ -195,7 +193,7 @@ def __init__( ): super().__init__(*angles, kappa=kappa, rotation_gate=rotation_gate) if not interleaved_unitaries: - identity_gate = cirq.IdentityGate(self.rotations_target.total_bits()) + identity_gate = cirq.IdentityGate(infra.total_bits(self.rotations_target)) interleaved_unitaries = (identity_gate,) * (len(angles) - 1) assert len(interleaved_unitaries) == len(angles) - 1 assert all(cirq.num_qubits(u) == self._target_bitsize for u in interleaved_unitaries) @@ -205,5 +203,5 @@ def interleaved_unitary(self, index: int, **qubit_regs: NDArray[cirq.Qid]) -> ci return self._interleaved_unitaries[index].on(*qubit_regs['rotations_target']) @cached_property - def interleaved_unitary_target(self) -> infra.Registers: - return infra.Registers.build() + def interleaved_unitary_target(self) -> Tuple[infra.Register, ...]: + return () diff --git a/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array_test.py b/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array_test.py index a860bb07b6e..cdd4212bcc3 100644 --- a/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array_test.py +++ b/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array_test.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple from numpy.typing import NDArray import cirq import cirq_ft import numpy as np import pytest +from cirq_ft import infra from cirq._compat import cached_property from cirq_ft.infra.bit_tools import iter_bits @@ -33,8 +35,8 @@ def interleaved_unitary( return two_qubit_ops_factory[index % 2] @cached_property - def interleaved_unitary_target(self) -> cirq_ft.Registers: - return cirq_ft.Registers.build(unrelated_target=1) + def interleaved_unitary_target(self) -> Tuple[cirq_ft.Register, ...]: + return tuple(cirq_ft.Registers.build(unrelated_target=1)) def construct_custom_prga(*args, **kwargs) -> cirq_ft.ProgrammableRotationGateArrayBase: @@ -78,7 +80,7 @@ def test_programmable_rotation_gate_array(angles, kappa, constructor): *programmable_rotation_gate.interleaved_unitary_target, ] ) - rotations_and_unitary_qubits = rotations_and_unitary_registers.merge_qubits(**g.quregs) + rotations_and_unitary_qubits = infra.merge_qubits(rotations_and_unitary_registers, **g.quregs) # Build circuit. simulator = cirq.Simulator(dtype=np.complex128) diff --git a/cirq-ft/cirq_ft/algos/qrom.py b/cirq-ft/cirq_ft/algos/qrom.py index fdbe36792cc..8d09d82ed9b 100644 --- a/cirq-ft/cirq_ft/algos/qrom.py +++ b/cirq-ft/cirq_ft/algos/qrom.py @@ -92,36 +92,26 @@ def __attrs_post_init__(self): assert isinstance(self.target_bitsizes, tuple) @cached_property - def control_registers(self) -> infra.Registers: - return ( - infra.Registers.build(control=self.num_controls) - if self.num_controls - else infra.Registers([]) - ) + def control_registers(self) -> Tuple[infra.Register, ...]: + return () if not self.num_controls else (infra.Register('control', self.num_controls),) @cached_property - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: if len(self.data[0].shape) == 1: - return infra.SelectionRegisters( - [ - infra.SelectionRegister( - 'selection', self.selection_bitsizes[0], self.data[0].shape[0] - ) - ] + return ( + infra.SelectionRegister( + 'selection', self.selection_bitsizes[0], self.data[0].shape[0] + ), ) else: - return infra.SelectionRegisters( - [ - infra.SelectionRegister(f'selection{i}', sb, len) - for i, (len, sb) in enumerate(zip(self.data[0].shape, self.selection_bitsizes)) - ] + return tuple( + infra.SelectionRegister(f'selection{i}', sb, l) + for i, (l, sb) in enumerate(zip(self.data[0].shape, self.selection_bitsizes)) ) @cached_property - def target_registers(self) -> infra.Registers: - return infra.Registers.build( - **{f'target{i}': len for i, len in enumerate(self.target_bitsizes)} - ) + def target_registers(self) -> Tuple[infra.Register, ...]: + return tuple(infra.Register(f'target{i}', l) for i, l in enumerate(self.target_bitsizes)) def __repr__(self) -> str: data_repr = f"({','.join(cirq._compat.proper_repr(d) for d in self.data)})" @@ -147,8 +137,8 @@ def _load_nth_data( def decompose_zero_selection( self, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: - controls = self.control_registers.merge_qubits(**quregs) - target_regs = {k: v for k, v in quregs.items() if k in self.target_registers} + controls = infra.merge_qubits(self.control_registers, **quregs) + target_regs = {reg.name: quregs[reg.name] for reg in self.target_registers} zero_indx = (0,) * len(self.data[0].shape) if self.num_controls == 0: yield self._load_nth_data(zero_indx, cirq.X, **target_regs) @@ -181,7 +171,7 @@ def nth_operation( def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo: wire_symbols = ["@"] * self.num_controls - wire_symbols += ["In"] * self.selection_registers.total_bits() + wire_symbols += ["In"] * infra.total_bits(self.selection_registers) for i, target in enumerate(self.target_registers): wire_symbols += [f"QROM_{i}"] * target.total_bits() return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) diff --git a/cirq-ft/cirq_ft/algos/qrom_test.py b/cirq-ft/cirq_ft/algos/qrom_test.py index 01025ac38c5..514e7f03935 100644 --- a/cirq-ft/cirq_ft/algos/qrom_test.py +++ b/cirq-ft/cirq_ft/algos/qrom_test.py @@ -18,6 +18,7 @@ import cirq_ft import numpy as np import pytest +from cirq_ft import infra from cirq_ft.infra.bit_tools import iter_bits from cirq_ft.infra.jupyter_tools import execute_notebook @@ -34,7 +35,8 @@ def test_qrom_1d(data, num_controls): inverse = cirq.Circuit(cirq.decompose(g.operation**-1, context=g.context)) assert ( - len(inverse.all_qubits()) <= g.r.total_bits() + g.r['selection'].total_bits() + num_controls + len(inverse.all_qubits()) + <= infra.total_bits(g.r) + g.r['selection'].total_bits() + num_controls ) assert inverse.all_qubits() == decomposed_circuit.all_qubits() @@ -73,7 +75,7 @@ def test_qrom_diagram(): d1 = np.array([4, 5, 6]) qrom = cirq_ft.QROM.build(d0, d1) q = cirq.LineQubit.range(cirq.num_qubits(qrom)) - circuit = cirq.Circuit(qrom.on_registers(**qrom.registers.split_qubits(q))) + circuit = cirq.Circuit(qrom.on_registers(**infra.split_qubits(qrom.registers, q))) cirq.testing.assert_has_diagram( circuit, """ @@ -213,7 +215,7 @@ def test_qrom_multi_dim(data, num_controls): assert ( len(inverse.all_qubits()) - <= g.r.total_bits() + qrom.selection_registers.total_bits() + num_controls + <= infra.total_bits(g.r) + infra.total_bits(qrom.selection_registers) + num_controls ) assert inverse.all_qubits() == decomposed_circuit.all_qubits() diff --git a/cirq-ft/cirq_ft/algos/qubitization_walk_operator.py b/cirq-ft/cirq_ft/algos/qubitization_walk_operator.py index 6910de45905..f39964af93f 100644 --- a/cirq-ft/cirq_ft/algos/qubitization_walk_operator.py +++ b/cirq-ft/cirq_ft/algos/qubitization_walk_operator.py @@ -61,15 +61,15 @@ def __attrs_post_init__(self): assert self.select.control_registers == self.reflect.control_registers @cached_property - def control_registers(self) -> infra.Registers: + def control_registers(self) -> Tuple[infra.Register, ...]: return self.select.control_registers @cached_property - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: return self.prepare.selection_registers @cached_property - def target_registers(self) -> infra.Registers: + def target_registers(self) -> Tuple[infra.Register, ...]: return self.select.target_registers @cached_property @@ -99,8 +99,12 @@ def decompose_from_registers( yield reflect_op def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: - wire_symbols = ['@' if self.control_val else '@(0)'] * self.control_registers.total_bits() - wire_symbols += ['W'] * (self.registers.total_bits() - self.control_registers.total_bits()) + wire_symbols = ['@' if self.control_val else '@(0)'] * infra.total_bits( + self.control_registers + ) + wire_symbols += ['W'] * ( + infra.total_bits(self.registers) - infra.total_bits(self.control_registers) + ) wire_symbols[-1] = f'W^{self.power}' if self.power != 1 else 'W' return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) diff --git a/cirq-ft/cirq_ft/algos/qubitization_walk_operator_test.py b/cirq-ft/cirq_ft/algos/qubitization_walk_operator_test.py index 8ea413661da..9b54501e99c 100644 --- a/cirq-ft/cirq_ft/algos/qubitization_walk_operator_test.py +++ b/cirq-ft/cirq_ft/algos/qubitization_walk_operator_test.py @@ -16,6 +16,7 @@ import cirq_ft import numpy as np import pytest +from cirq_ft import infra from cirq_ft.algos.generic_select_test import get_1d_Ising_hamiltonian from cirq_ft.algos.reflection_using_prepare_test import greedily_allocate_ancilla, keep from cirq_ft.infra.jupyter_tools import execute_notebook @@ -31,7 +32,9 @@ def walk_operator_for_pauli_hamiltonian( ham_coeff, probability_epsilon=eps ) select = cirq_ft.GenericSelect( - prepare.selection_registers.total_bits(), select_unitaries=ham_dps, target_bitsize=len(q) + infra.total_bits(prepare.selection_registers), + select_unitaries=ham_dps, + target_bitsize=len(q), ) return cirq_ft.QubitizationWalkOperator(select=select, prepare=prepare) @@ -96,7 +99,7 @@ def test_qubitization_walk_operator_diagrams(): num_sites, eps = 4, 1e-1 walk = get_walk_operator_for_1d_Ising_model(num_sites, eps) # 1. Diagram for $W = SELECT.R_{L}$ - qu_regs = walk.registers.get_named_qubits() + qu_regs = infra.get_named_qubits(walk.registers) walk_op = walk.on_registers(**qu_regs) circuit = cirq.Circuit(cirq.decompose_once(walk_op)) cirq.testing.assert_has_diagram( @@ -214,7 +217,7 @@ def keep(op): def test_qubitization_walk_operator_consistent_protocols_and_controlled(): gate = get_walk_operator_for_1d_Ising_model(4, 1e-1) - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) # Test consistent repr cirq.testing.assert_equivalent_repr( gate, setup_code='import cirq\nimport cirq_ft\nimport numpy as np' diff --git a/cirq-ft/cirq_ft/algos/reflection_using_prepare.py b/cirq-ft/cirq_ft/algos/reflection_using_prepare.py index 361644fc5ee..980465f524d 100644 --- a/cirq-ft/cirq_ft/algos/reflection_using_prepare.py +++ b/cirq-ft/cirq_ft/algos/reflection_using_prepare.py @@ -57,12 +57,11 @@ class ReflectionUsingPrepare(infra.GateWithRegisters): control_val: Optional[int] = None @cached_property - def control_registers(self) -> infra.Registers: - registers = [] if self.control_val is None else [infra.Register('control', 1)] - return infra.Registers(registers) + def control_registers(self) -> Tuple[infra.Register, ...]: + return () if self.control_val is None else (infra.Register('control', 1),) @cached_property - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: return self.prepare_gate.selection_registers @cached_property @@ -87,7 +86,7 @@ def decompose_from_registers( # 1. PREPARE† yield cirq.inverse(prepare_op) # 2. MultiControlled Z, controlled on |000..00> state. - phase_control = self.selection_registers.merge_qubits(**state_prep_selection_regs) + phase_control = infra.merge_qubits(self.selection_registers, **state_prep_selection_regs) yield cirq.X(phase_target) if not self.control_val else [] yield mcmt.MultiControlPauli([0] * len(phase_control), target_gate=cirq.Z).on_registers( controls=phase_control, target=phase_target @@ -102,8 +101,10 @@ def decompose_from_registers( qm.qfree([phase_target]) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: - wire_symbols = ['@' if self.control_val else '@(0)'] * self.control_registers.total_bits() - wire_symbols += ['R_L'] * self.selection_registers.total_bits() + wire_symbols = ['@' if self.control_val else '@(0)'] * infra.total_bits( + self.control_registers + ) + wire_symbols += ['R_L'] * infra.total_bits(self.selection_registers) return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) def __repr__(self): diff --git a/cirq-ft/cirq_ft/algos/reflection_using_prepare_test.py b/cirq-ft/cirq_ft/algos/reflection_using_prepare_test.py index 138415bfd35..b4a74c56ea7 100644 --- a/cirq-ft/cirq_ft/algos/reflection_using_prepare_test.py +++ b/cirq-ft/cirq_ft/algos/reflection_using_prepare_test.py @@ -16,6 +16,7 @@ import cirq import cirq_ft +from cirq_ft import infra import numpy as np import pytest @@ -108,7 +109,7 @@ def test_reflection_using_prepare_diagram(): ) # No control gate = cirq_ft.ReflectionUsingPrepare(prepare_gate, control_val=None) - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) circuit = greedily_allocate_ancilla(cirq.Circuit(cirq.decompose_once(op))) cirq.testing.assert_has_diagram( circuit, @@ -138,7 +139,7 @@ def test_reflection_using_prepare_diagram(): # Control on `|1>` state gate = cirq_ft.ReflectionUsingPrepare(prepare_gate, control_val=1) - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) circuit = greedily_allocate_ancilla(cirq.Circuit(cirq.decompose_once(op))) cirq.testing.assert_has_diagram( circuit, @@ -167,7 +168,7 @@ def test_reflection_using_prepare_diagram(): # Control on `|0>` state gate = cirq_ft.ReflectionUsingPrepare(prepare_gate, control_val=0) - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) circuit = greedily_allocate_ancilla(cirq.Circuit(cirq.decompose_once(op))) cirq.testing.assert_has_diagram( circuit, @@ -203,7 +204,7 @@ def test_reflection_using_prepare_consistent_protocols_and_controlled(): ) # No control gate = cirq_ft.ReflectionUsingPrepare(prepare_gate, control_val=None) - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) # Test consistent repr cirq.testing.assert_equivalent_repr( gate, setup_code='import cirq\nimport cirq_ft\nimport numpy as np' diff --git a/cirq-ft/cirq_ft/algos/select_and_prepare.py b/cirq-ft/cirq_ft/algos/select_and_prepare.py index b85fbccdfb3..72958b7fb4f 100644 --- a/cirq-ft/cirq_ft/algos/select_and_prepare.py +++ b/cirq-ft/cirq_ft/algos/select_and_prepare.py @@ -13,6 +13,7 @@ # limitations under the License. import abc +from typing import Tuple from cirq._compat import cached_property from cirq_ft import infra @@ -38,17 +39,17 @@ class SelectOracle(infra.GateWithRegisters): @property @abc.abstractmethod - def control_registers(self) -> infra.Registers: + def control_registers(self) -> Tuple[infra.Register, ...]: ... @property @abc.abstractmethod - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: ... @property @abc.abstractmethod - def target_registers(self) -> infra.Registers: + def target_registers(self) -> Tuple[infra.Register, ...]: ... @cached_property @@ -75,12 +76,12 @@ class PrepareOracle(infra.GateWithRegisters): @property @abc.abstractmethod - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: ... @cached_property - def junk_registers(self) -> infra.Registers: - return infra.Registers([]) + def junk_registers(self) -> Tuple[infra.Register, ...]: + return () @cached_property def registers(self) -> infra.Registers: diff --git a/cirq-ft/cirq_ft/algos/select_swap_qrom.py b/cirq-ft/cirq_ft/algos/select_swap_qrom.py index 4cde2b1f172..248dd1ab4f5 100644 --- a/cirq-ft/cirq_ft/algos/select_swap_qrom.py +++ b/cirq-ft/cirq_ft/algos/select_swap_qrom.py @@ -138,21 +138,19 @@ def __init__( self._data = tuple(tuple(d) for d in data) @cached_property - def selection_registers(self) -> infra.SelectionRegisters: - return infra.SelectionRegisters( - [ - infra.SelectionRegister( - 'selection', self.selection_q + self.selection_r, self._iteration_length - ) - ] + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: + return ( + infra.SelectionRegister( + 'selection', self.selection_q + self.selection_r, self._iteration_length + ), ) @cached_property - def target_registers(self) -> infra.Registers: - clean_output = {} - for sequence_id in range(self._num_sequences): - clean_output[f'target{sequence_id}'] = self._target_bitsizes[sequence_id] - return infra.Registers.build(**clean_output) + def target_registers(self) -> Tuple[infra.Register, ...]: + return tuple( + infra.Register(f'target{sequence_id}', self._target_bitsizes[sequence_id]) + for sequence_id in range(self._num_sequences) + ) @cached_property def registers(self) -> infra.Registers: @@ -212,15 +210,16 @@ def decompose_from_registers( target_bitsizes=tuple(qrom_target_bitsizes), ) qrom_op = qrom_gate.on_registers( - selection=q, **qrom_gate.target_registers.split_qubits(ordered_target_qubits) + selection=q, **infra.split_qubits(qrom_gate.target_registers, ordered_target_qubits) ) swap_with_zero_gate = swap_network.SwapWithZeroGate( - k, self.target_registers.total_bits(), self.block_size + k, infra.total_bits(self.target_registers), self.block_size ) swap_with_zero_op = swap_with_zero_gate.on_registers( - selection=r, **swap_with_zero_gate.target_registers.split_qubits(ordered_target_qubits) + selection=r, + **infra.split_qubits(swap_with_zero_gate.target_registers, ordered_target_qubits), ) - clean_targets = self.target_registers.merge_qubits(**targets) + clean_targets = infra.merge_qubits(self.target_registers, **targets) cnot_op = cirq.Moment(cirq.CNOT(s, t) for s, t in zip(ordered_target_qubits, clean_targets)) # Yield the operations in correct order. yield qrom_op diff --git a/cirq-ft/cirq_ft/algos/select_swap_qrom_test.py b/cirq-ft/cirq_ft/algos/select_swap_qrom_test.py index f040f312bbf..85c64425d63 100644 --- a/cirq-ft/cirq_ft/algos/select_swap_qrom_test.py +++ b/cirq-ft/cirq_ft/algos/select_swap_qrom_test.py @@ -16,6 +16,7 @@ import cirq_ft import numpy as np import pytest +from cirq_ft import infra from cirq_ft.infra.bit_tools import iter_bits @@ -23,7 +24,7 @@ @pytest.mark.parametrize("block_size", [None, 1, 2, 3]) def test_select_swap_qrom(data, block_size): qrom = cirq_ft.SelectSwapQROM(*data, block_size=block_size) - qubit_regs = qrom.registers.get_named_qubits() + qubit_regs = infra.get_named_qubits(qrom.registers) selection = qubit_regs["selection"] selection_q, selection_r = selection[: qrom.selection_q], selection[qrom.selection_q :] targets = [qubit_regs[f"target{i}"] for i in range(len(data))] @@ -47,7 +48,7 @@ def test_select_swap_qrom(data, block_size): cirq.H.on_each(*dirty_target_ancilla), ) all_qubits = sorted(circuit.all_qubits()) - for selection_integer in range(qrom.selection_registers.iteration_lengths[0]): + for selection_integer in range(qrom.selection_registers[0].iteration_length): svals_q = list(iter_bits(selection_integer // qrom.block_size, len(selection_q))) svals_r = list(iter_bits(selection_integer % qrom.block_size, len(selection_r))) qubit_vals = {x: 0 for x in all_qubits} @@ -77,7 +78,7 @@ def test_qroam_diagram(): blocksize = 2 qrom = cirq_ft.SelectSwapQROM(*data, block_size=blocksize) q = cirq.LineQubit.range(cirq.num_qubits(qrom)) - circuit = cirq.Circuit(qrom.on_registers(**qrom.registers.split_qubits(q))) + circuit = cirq.Circuit(qrom.on_registers(**infra.split_qubits(qrom.registers, q))) cirq.testing.assert_has_diagram( circuit, """ diff --git a/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py b/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py index 501b23bb786..877c81f39a3 100644 --- a/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py +++ b/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Union +from typing import Sequence, Union, Tuple from numpy.typing import NDArray import attr @@ -34,7 +34,7 @@ class SelectedMajoranaFermionGate(unary_iteration_gate.UnaryIterationGate): Args: - selection_regs: Indexing `select` registers of type `SelectionRegisters`. It also contains + selection_regs: Indexing `select` registers of type `SelectionRegister`. It also contains information about the iteration length of each selection register. control_regs: Control registers for constructing a controlled version of the gate. target_gate: Single qubit gate to be applied to the target qubits. @@ -43,8 +43,13 @@ class SelectedMajoranaFermionGate(unary_iteration_gate.UnaryIterationGate): See Fig 9 of https://arxiv.org/abs/1805.03662 for more details. """ - selection_regs: infra.SelectionRegisters - control_regs: infra.Registers = infra.Registers.build(control=1) + selection_regs: Tuple[infra.SelectionRegister, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, infra.SelectionRegister) else tuple(v) + ) + control_regs: Tuple[infra.Register, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, infra.Register) else tuple(v), + default=(infra.Register('control', 1),), + ) target_gate: cirq.Gate = cirq.Y @classmethod @@ -55,38 +60,39 @@ def make_on( **quregs: Union[Sequence[cirq.Qid], NDArray[cirq.Qid]], # type: ignore[type-var] ) -> cirq.Operation: """Helper constructor to automatically deduce selection_regs attribute.""" - return cls( - selection_regs=infra.SelectionRegisters( - [ - infra.SelectionRegister( - 'selection', len(quregs['selection']), len(quregs['target']) - ) - ] + return SelectedMajoranaFermionGate( + selection_regs=infra.SelectionRegister( + 'selection', len(quregs['selection']), len(quregs['target']) ), target_gate=target_gate, ).on_registers(**quregs) @cached_property - def control_registers(self) -> infra.Registers: + def control_registers(self) -> Tuple[infra.Register, ...]: return self.control_regs @cached_property - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: return self.selection_regs @cached_property - def target_registers(self) -> infra.Registers: - return infra.Registers.build(target=self.selection_regs.total_iteration_size) + def target_registers(self) -> Tuple[infra.Register, ...]: + total_iteration_size = np.product( + tuple(reg.iteration_length for reg in self.selection_registers) + ) + return (infra.Register('target', int(total_iteration_size)),) @cached_property - def extra_registers(self) -> infra.Registers: - return infra.Registers.build(accumulator=1) + def extra_registers(self) -> Tuple[infra.Register, ...]: + return (infra.Register('accumulator', 1),) def decompose_from_registers( self, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: quregs['accumulator'] = np.array(context.qubit_manager.qalloc(1)) - control = quregs[self.control_regs[0].name] if self.control_registers.total_bits() else [] + control = ( + quregs[self.control_regs[0].name] if infra.total_bits(self.control_registers) else [] + ) yield cirq.X(*quregs['accumulator']).controlled_by(*control) yield super(SelectedMajoranaFermionGate, self).decompose_from_registers( context=context, **quregs @@ -94,9 +100,9 @@ def decompose_from_registers( context.qubit_manager.qfree(quregs['accumulator']) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: - wire_symbols = ["@"] * self.control_registers.total_bits() - wire_symbols += ["In"] * self.selection_registers.total_bits() - wire_symbols += [f"Z{self.target_gate}"] * self.target_registers.total_bits() + wire_symbols = ["@"] * infra.total_bits(self.control_registers) + wire_symbols += ["In"] * infra.total_bits(self.selection_registers) + wire_symbols += [f"Z{self.target_gate}"] * infra.total_bits(self.target_registers) return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) def nth_operation( # type: ignore[override] @@ -107,8 +113,9 @@ def nth_operation( # type: ignore[override] accumulator: Sequence[cirq.Qid], **selection_indices: int, ) -> cirq.OP_TREE: + selection_shape = tuple(reg.iteration_length for reg in self.selection_regs) selection_idx = tuple(selection_indices[reg.name] for reg in self.selection_regs) - target_idx = self.selection_registers.to_flat_idx(*selection_idx) + target_idx = int(np.ravel_multi_index(selection_idx, selection_shape)) yield cirq.CNOT(control, *accumulator) yield self.target_gate(target[target_idx]).controlled_by(control) yield cirq.CZ(*accumulator, target[target_idx]) diff --git a/cirq-ft/cirq_ft/algos/selected_majorana_fermion_test.py b/cirq-ft/cirq_ft/algos/selected_majorana_fermion_test.py index cb674c51cd2..9367bcc607f 100644 --- a/cirq-ft/cirq_ft/algos/selected_majorana_fermion_test.py +++ b/cirq-ft/cirq_ft/algos/selected_majorana_fermion_test.py @@ -16,6 +16,7 @@ import cirq_ft import numpy as np import pytest +from cirq_ft import infra from cirq_ft.infra.bit_tools import iter_bits @@ -23,13 +24,11 @@ @pytest.mark.parametrize("target_gate", [cirq.X, cirq.Y]) def test_selected_majorana_fermion_gate(selection_bitsize, target_bitsize, target_gate): gate = cirq_ft.SelectedMajoranaFermionGate( - cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)] - ), + cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize), target_gate=target_gate, ) g = cirq_ft.testing.GateHelper(gate) - assert len(g.all_qubits) <= gate.registers.total_bits() + selection_bitsize + 1 + assert len(g.all_qubits) <= infra.total_bits(gate.registers) + selection_bitsize + 1 sim = cirq.Simulator(dtype=np.complex128) for n in range(target_bitsize): @@ -65,13 +64,11 @@ def test_selected_majorana_fermion_gate(selection_bitsize, target_bitsize, targe def test_selected_majorana_fermion_gate_diagram(): selection_bitsize, target_bitsize = 3, 5 gate = cirq_ft.SelectedMajoranaFermionGate( - cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)] - ), + cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize), target_gate=cirq.X, ) - circuit = cirq.Circuit(gate.on_registers(**gate.registers.get_named_qubits())) - qubits = list(q for v in gate.registers.get_named_qubits().values() for q in v) + circuit = cirq.Circuit(gate.on_registers(**infra.get_named_qubits(gate.registers))) + qubits = list(q for v in infra.get_named_qubits(gate.registers).values() for q in v) cirq.testing.assert_has_diagram( circuit, """ @@ -100,9 +97,7 @@ def test_selected_majorana_fermion_gate_diagram(): def test_selected_majorana_fermion_gate_decomposed_diagram(): selection_bitsize, target_bitsize = 2, 3 gate = cirq_ft.SelectedMajoranaFermionGate( - cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)] - ), + cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize), target_gate=cirq.X, ) greedy_mm = cirq_ft.GreedyQubitManager(prefix="_a", maximize_reuse=True) @@ -145,13 +140,11 @@ def test_selected_majorana_fermion_gate_decomposed_diagram(): def test_selected_majorana_fermion_gate_make_on(): selection_bitsize, target_bitsize = 3, 5 gate = cirq_ft.SelectedMajoranaFermionGate( - cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize)] - ), + cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize), target_gate=cirq.X, ) - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) op2 = cirq_ft.SelectedMajoranaFermionGate.make_on( - target_gate=cirq.X, **gate.registers.get_named_qubits() + target_gate=cirq.X, **infra.get_named_qubits(gate.registers) ) assert op == op2 diff --git a/cirq-ft/cirq_ft/algos/state_preparation.py b/cirq-ft/cirq_ft/algos/state_preparation.py index cb7fd397b03..bec54f50a6b 100644 --- a/cirq-ft/cirq_ft/algos/state_preparation.py +++ b/cirq-ft/cirq_ft/algos/state_preparation.py @@ -20,7 +20,7 @@ largest absolute error that one can tolerate in the prepared amplitudes. """ -from typing import List +from typing import List, Tuple from numpy.typing import NDArray import attr @@ -83,7 +83,9 @@ class StatePreparationAliasSampling(select_and_prepare.PrepareOracle): (https://arxiv.org/abs/1805.03662). Babbush et. al. (2018). Section III.D. and Figure 11. """ - selection_registers: infra.SelectionRegisters + selection_registers: Tuple[infra.SelectionRegister, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, infra.SelectionRegister) else tuple(v) + ) alt: NDArray[np.int_] keep: NDArray[np.int_] mu: int @@ -106,9 +108,7 @@ def from_lcu_probs( ) N = len(lcu_probabilities) return StatePreparationAliasSampling( - selection_registers=infra.SelectionRegisters( - [infra.SelectionRegister('selection', (N - 1).bit_length(), N)] - ), + selection_registers=infra.SelectionRegister('selection', (N - 1).bit_length(), N), alt=np.array(alt), keep=np.array(keep), mu=mu, @@ -120,7 +120,7 @@ def sigma_mu_bitsize(self) -> int: @cached_property def alternates_bitsize(self) -> int: - return self.selection_registers.total_bits() + return infra.total_bits(self.selection_registers) @cached_property def keep_bitsize(self) -> int: @@ -128,15 +128,17 @@ def keep_bitsize(self) -> int: @cached_property def selection_bitsize(self) -> int: - return self.selection_registers.total_bits() + return infra.total_bits(self.selection_registers) @cached_property - def junk_registers(self) -> infra.Registers: - return infra.Registers.build( - sigma_mu=self.sigma_mu_bitsize, - alt=self.alternates_bitsize, - keep=self.keep_bitsize, - less_than_equal=1, + def junk_registers(self) -> Tuple[infra.Register, ...]: + return tuple( + infra.Registers.build( + sigma_mu=self.sigma_mu_bitsize, + alt=self.alternates_bitsize, + keep=self.keep_bitsize, + less_than_equal=1, + ) ) def _value_equality_values_(self): diff --git a/cirq-ft/cirq_ft/algos/swap_network.py b/cirq-ft/cirq_ft/algos/swap_network.py index 480bf0e4f3c..279ab33be38 100644 --- a/cirq-ft/cirq_ft/algos/swap_network.py +++ b/cirq-ft/cirq_ft/algos/swap_network.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Union +from typing import Sequence, Union, Tuple from numpy.typing import NDArray import attr @@ -145,14 +145,14 @@ def __attrs_post_init__(self): assert self.n_target_registers <= 2**self.selection_bitsize @cached_property - def selection_registers(self) -> infra.SelectionRegisters: - return infra.SelectionRegisters( - [infra.SelectionRegister('selection', self.selection_bitsize, self.n_target_registers)] + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: + return ( + infra.SelectionRegister('selection', self.selection_bitsize, self.n_target_registers), ) @cached_property - def target_registers(self) -> infra.Registers: - return infra.Registers.build(target=(self.n_target_registers, self.target_bitsize)) + def target_registers(self) -> Tuple[infra.Register, ...]: + return (infra.Register('target', (self.n_target_registers, self.target_bitsize)),) @cached_property def registers(self) -> infra.Registers: diff --git a/cirq-ft/cirq_ft/algos/swap_network_test.py b/cirq-ft/cirq_ft/algos/swap_network_test.py index c934ca4dec3..92cb5865a2a 100644 --- a/cirq-ft/cirq_ft/algos/swap_network_test.py +++ b/cirq-ft/cirq_ft/algos/swap_network_test.py @@ -18,6 +18,7 @@ import cirq_ft import numpy as np import pytest +from cirq_ft import infra from cirq_ft.infra.jupyter_tools import execute_notebook random.seed(12345) @@ -65,7 +66,7 @@ def test_swap_with_zero_gate(selection_bitsize, target_bitsize, n_target_registe def test_swap_with_zero_gate_diagram(): gate = cirq_ft.SwapWithZeroGate(3, 2, 4) q = cirq.LineQubit.range(cirq.num_qubits(gate)) - circuit = cirq.Circuit(gate.on_registers(**gate.registers.split_qubits(q))) + circuit = cirq.Circuit(gate.on_registers(**infra.split_qubits(gate.registers, q))) cirq.testing.assert_has_diagram( circuit, """ diff --git a/cirq-ft/cirq_ft/algos/unary_iteration.ipynb b/cirq-ft/cirq_ft/algos/unary_iteration.ipynb index 003941203c3..4eabc65a0af 100644 --- a/cirq-ft/cirq_ft/algos/unary_iteration.ipynb +++ b/cirq-ft/cirq_ft/algos/unary_iteration.ipynb @@ -471,7 +471,7 @@ "metadata": {}, "outputs": [], "source": [ - "from cirq_ft import Registers, SelectionRegister, SelectionRegisters, UnaryIterationGate\n", + "from cirq_ft import Register, Registers, SelectionRegister, UnaryIterationGate\n", "from cirq._compat import cached_property\n", "\n", "class ApplyXToLthQubit(UnaryIterationGate):\n", @@ -481,16 +481,16 @@ " self._control_bitsize = control_bitsize\n", "\n", " @cached_property\n", - " def control_registers(self) -> Registers:\n", - " return Registers.build(control=self._control_bitsize)\n", + " def control_registers(self) -> Tuple[Register, ...]:\n", + " return Register('control', self._control_bitsize),\n", "\n", " @cached_property\n", - " def selection_registers(self) -> SelectionRegisters:\n", - " return SelectionRegisters([SelectionRegister('selection', self._selection_bitsize, self._target_bitsize)])\n", + " def selection_registers(self) -> Tuple[SelectionRegister, ...]:\n", + " return SelectionRegister('selection', self._selection_bitsize, self._target_bitsize),\n", "\n", " @cached_property\n", - " def target_registers(self) -> Registers:\n", - " return Registers.build(target=self._target_bitsize)\n", + " def target_registers(self) -> Tuple[Register, ...]:\n", + " return Register('target', self._target_bitsize),\n", "\n", " def nth_operation(\n", " self, context, control: cirq.Qid, selection: int, target: Sequence[cirq.Qid]\n", diff --git a/cirq-ft/cirq_ft/algos/unary_iteration_gate.py b/cirq-ft/cirq_ft/algos/unary_iteration_gate.py index f37c2718b3b..d72ab7381ce 100644 --- a/cirq-ft/cirq_ft/algos/unary_iteration_gate.py +++ b/cirq-ft/cirq_ft/algos/unary_iteration_gate.py @@ -268,17 +268,17 @@ class UnaryIterationGate(infra.GateWithRegisters): @cached_property @abc.abstractmethod - def control_registers(self) -> infra.Registers: + def control_registers(self) -> Tuple[infra.Register, ...]: pass @cached_property @abc.abstractmethod - def selection_registers(self) -> infra.SelectionRegisters: + def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: pass @cached_property @abc.abstractmethod - def target_registers(self) -> infra.Registers: + def target_registers(self) -> Tuple[infra.Register, ...]: pass @cached_property @@ -288,8 +288,8 @@ def registers(self) -> infra.Registers: ) @cached_property - def extra_registers(self) -> infra.Registers: - return infra.Registers([]) + def extra_registers(self) -> Tuple[infra.Register, ...]: + return () @abc.abstractmethod def nth_operation( @@ -325,7 +325,7 @@ def decompose_zero_selection( By default, if the selection register is empty, the decomposition will raise a `NotImplementedError`. The derived classes can override this method and specify a custom decomposition that should be used if the selection register is empty, - i.e. `self.selection_registers.total_bits() == 0`. + i.e. `infra.total_bits(self.selection_registers) == 0`. The derived classes should specify the following arguments as `**kwargs`: 1) Register names in `self.control_registers`: Each argument corresponds to a @@ -366,14 +366,14 @@ def _break_early(self, selection_index_prefix: Tuple[int, ...], l: int, r: int) def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: - if self.selection_registers.total_bits() == 0 or self._break_early( + if infra.total_bits(self.selection_registers) == 0 or self._break_early( (), 0, self.selection_registers[0].iteration_length ): return self.decompose_zero_selection(context=context, **quregs) num_loops = len(self.selection_registers) - target_regs = {k: v for k, v in quregs.items() if k in self.target_registers} - extra_regs = {k: v for k, v in quregs.items() if k in self.extra_registers} + target_regs = {reg.name: quregs[reg.name] for reg in self.target_registers} + extra_regs = {reg.name: quregs[reg.name] for reg in self.extra_registers} def unary_iteration_loops( nested_depth: int, @@ -430,7 +430,7 @@ def unary_iteration_loops( selection_reg_name_to_val.pop(self.selection_registers[nested_depth].name) yield ops - return unary_iteration_loops(0, {}, self.control_registers.merge_qubits(**quregs)) + return unary_iteration_loops(0, {}, infra.merge_qubits(self.control_registers, **quregs)) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: """Basic circuit diagram. @@ -438,7 +438,7 @@ def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.Circ Descendants are encouraged to override this with more descriptive circuit diagram information. """ - wire_symbols = ["@"] * self.control_registers.total_bits() - wire_symbols += ["In"] * self.selection_registers.total_bits() - wire_symbols += [self.__class__.__name__] * self.target_registers.total_bits() + wire_symbols = ["@"] * infra.total_bits(self.control_registers) + wire_symbols += ["In"] * infra.total_bits(self.selection_registers) + wire_symbols += [self.__class__.__name__] * infra.total_bits(self.target_registers) return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) diff --git a/cirq-ft/cirq_ft/algos/unary_iteration_gate_test.py b/cirq-ft/cirq_ft/algos/unary_iteration_gate_test.py index ffa50adc940..0c754fc5980 100644 --- a/cirq-ft/cirq_ft/algos/unary_iteration_gate_test.py +++ b/cirq-ft/cirq_ft/algos/unary_iteration_gate_test.py @@ -19,6 +19,7 @@ import cirq_ft import pytest from cirq._compat import cached_property +from cirq_ft import infra from cirq_ft.infra.bit_tools import iter_bits from cirq_ft.infra.jupyter_tools import execute_notebook @@ -30,18 +31,18 @@ def __init__(self, selection_bitsize: int, target_bitsize: int, control_bitsize: self._control_bitsize = control_bitsize @cached_property - def control_registers(self) -> cirq_ft.Registers: - return cirq_ft.Registers.build(control=self._control_bitsize) + def control_registers(self) -> Tuple[cirq_ft.Register, ...]: + return (cirq_ft.Register('control', self._control_bitsize),) @cached_property - def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('selection', self._selection_bitsize, self._target_bitsize)] + def selection_registers(self) -> Tuple[cirq_ft.SelectionRegister, ...]: + return ( + cirq_ft.SelectionRegister('selection', self._selection_bitsize, self._target_bitsize), ) @cached_property - def target_registers(self) -> cirq_ft.Registers: - return cirq_ft.Registers.build(target=self._target_bitsize) + def target_registers(self) -> Tuple[cirq_ft.Register, ...]: + return (cirq_ft.Register('target', self._target_bitsize),) def nth_operation( # type: ignore[override] self, @@ -83,24 +84,24 @@ def __init__(self, target_shape: Tuple[int, int, int]): self._target_shape = target_shape @cached_property - def control_registers(self) -> cirq_ft.Registers: - return cirq_ft.Registers([]) + def control_registers(self) -> Tuple[cirq_ft.Register, ...]: + return () @cached_property - def selection_registers(self) -> cirq_ft.SelectionRegisters: - return cirq_ft.SelectionRegisters( - [ - cirq_ft.SelectionRegister( - 'ijk'[i], (self._target_shape[i] - 1).bit_length(), self._target_shape[i] - ) - for i in range(3) - ] + def selection_registers(self) -> Tuple[cirq_ft.SelectionRegister, ...]: + return tuple( + cirq_ft.SelectionRegister( + 'ijk'[i], (self._target_shape[i] - 1).bit_length(), self._target_shape[i] + ) + for i in range(3) ) @cached_property - def target_registers(self) -> cirq_ft.Registers: - return cirq_ft.Registers.build( - t1=self._target_shape[0], t2=self._target_shape[1], t3=self._target_shape[2] + def target_registers(self) -> Tuple[cirq_ft.Register, ...]: + return tuple( + cirq_ft.Registers.build( + t1=self._target_shape[0], t2=self._target_shape[1], t3=self._target_shape[2] + ) ) def nth_operation( # type: ignore[override] @@ -123,7 +124,8 @@ def test_multi_dimensional_unary_iteration_gate(target_shape: Tuple[int, int, in gate = ApplyXToIJKthQubit(target_shape) g = cirq_ft.testing.GateHelper(gate, context=cirq.DecompositionContext(greedy_mm)) assert ( - len(g.all_qubits) <= gate.registers.total_bits() + gate.selection_registers.total_bits() - 1 + len(g.all_qubits) + <= infra.total_bits(gate.registers) + infra.total_bits(gate.selection_registers) - 1 ) max_i, max_j, max_k = target_shape @@ -147,10 +149,11 @@ def test_multi_dimensional_unary_iteration_gate(target_shape: Tuple[int, int, in def test_unary_iteration_loop(): n_range, m_range = (3, 5), (6, 8) - selection_registers = cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('n', 3, 5), cirq_ft.SelectionRegister('m', 3, 8)] - ) - selection = selection_registers.get_named_qubits() + selection_registers = [ + cirq_ft.SelectionRegister('n', 3, 5), + cirq_ft.SelectionRegister('m', 3, 8), + ] + selection = infra.get_named_qubits(selection_registers) target = {(n, m): cirq.q(f't({n}, {m})') for n in range(*n_range) for m in range(*m_range)} qm = cirq_ft.GreedyQubitManager("ancilla", maximize_reuse=True) circuit = cirq.Circuit() diff --git a/cirq-ft/cirq_ft/infra/__init__.py b/cirq-ft/cirq_ft/infra/__init__.py index bfc99572bef..02f503110ca 100644 --- a/cirq-ft/cirq_ft/infra/__init__.py +++ b/cirq-ft/cirq_ft/infra/__init__.py @@ -17,9 +17,11 @@ Register, Registers, SelectionRegister, - SelectionRegisters, + total_bits, + split_qubits, + merge_qubits, + get_named_qubits, ) from cirq_ft.infra.qubit_management_transformers import map_clean_and_borrowable_qubits from cirq_ft.infra.qubit_manager import GreedyQubitManager from cirq_ft.infra.t_complexity_protocol import TComplexity, t_complexity -from cirq_ft.infra.type_convertors import to_tuple diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb b/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb index 70e4a6e59ba..6afb6d49d4f 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb +++ b/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb @@ -49,7 +49,7 @@ "metadata": {}, "outputs": [], "source": [ - "from cirq_ft import Register, Registers\n", + "from cirq_ft import Register, Registers, infra\n", "\n", "control_reg = Register(name='control', shape=(2,))\n", "target_reg = Register(name='target', shape=(3,))\n", @@ -163,7 +163,7 @@ "outputs": [], "source": [ "r = gate.registers\n", - "quregs = r.get_named_qubits()\n", + "quregs = infra.get_named_qubits(r)\n", "quregs" ] }, diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers.py b/cirq-ft/cirq_ft/infra/gate_with_registers.py index c65d3578902..7139b66a65a 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers.py +++ b/cirq-ft/cirq_ft/infra/gate_with_registers.py @@ -43,7 +43,7 @@ def all_idxs(self) -> Iterable[Tuple[int, ...]]: def total_bits(self) -> int: """The total number of bits in this register. - This is the product of bitsize and each of the dimensions in `shape`. + This is the product of each of the dimensions in `shape`. """ return int(np.product(self.shape)) @@ -51,6 +51,68 @@ def __repr__(self): return f'cirq_ft.Register(name="{self.name}", shape={self.shape})' +def total_bits(registers: Iterable[Register]) -> int: + """Sum of `reg.total_bits()` for each register `reg` in input `registers`.""" + + return sum(reg.total_bits() for reg in registers) + + +def split_qubits( + registers: Iterable[Register], qubits: Sequence[cirq.Qid] +) -> Dict[str, NDArray[cirq.Qid]]: # type: ignore[type-var] + """Splits the flat list of qubits into a dictionary of appropriately shaped qubit arrays.""" + + qubit_regs = {} + base = 0 + for reg in registers: + qubit_regs[reg.name] = np.array(qubits[base : base + reg.total_bits()]).reshape(reg.shape) + base += reg.total_bits() + return qubit_regs + + +def merge_qubits( + registers: Iterable[Register], + **qubit_regs: Union[cirq.Qid, Sequence[cirq.Qid], NDArray[cirq.Qid]], +) -> List[cirq.Qid]: + """Merges the dictionary of appropriately shaped qubit arrays into a flat list of qubits.""" + + ret: List[cirq.Qid] = [] + for reg in registers: + if reg.name not in qubit_regs: + raise ValueError(f"All qubit registers must be present. {reg.name} not in qubit_regs") + qubits = qubit_regs[reg.name] + qubits = np.array([qubits] if isinstance(qubits, cirq.Qid) else qubits) + if qubits.shape != reg.shape: + raise ValueError( + f'{reg.name} register must of shape {reg.shape} but is of shape {qubits.shape}' + ) + ret += qubits.flatten().tolist() + return ret + + +def get_named_qubits(registers: Iterable[Register]) -> Dict[str, NDArray[cirq.Qid]]: + """Returns a dictionary of appropriately shaped named qubit registers for input `registers`.""" + + def _qubit_array(reg: Register): + qubits = np.empty(reg.shape, dtype=object) + for ii in reg.all_idxs(): + qubits[ii] = cirq.NamedQubit(f'{reg.name}[{", ".join(str(i) for i in ii)}]') + return qubits + + def _qubits_for_reg(reg: Register): + if len(reg.shape) > 1: + return _qubit_array(reg) + + return np.array( + [cirq.NamedQubit(f"{reg.name}")] + if reg.total_bits() == 1 + else cirq.NamedQubit.range(reg.total_bits(), prefix=reg.name), + dtype=object, + ) + + return {reg.name: _qubits_for_reg(reg) for reg in registers} + + class Registers: """An ordered collection of `cirq_ft.Register`. @@ -67,9 +129,6 @@ def __init__(self, registers: Iterable[Register]): def __repr__(self): return f'cirq_ft.Registers({self._registers})' - def total_bits(self) -> int: - return sum(reg.total_bits() for reg in self) - @classmethod def build(cls, **registers: Union[int, Tuple[int, ...]]) -> 'Registers': return cls(Register(name=k, shape=v) for k, v in registers.items()) @@ -105,54 +164,6 @@ def __iter__(self): def __len__(self) -> int: return len(self._registers) - def split_qubits( - self, qubits: Sequence[cirq.Qid] - ) -> Dict[str, NDArray[cirq.Qid]]: # type: ignore[type-var] - qubit_regs = {} - base = 0 - for reg in self: - qubit_regs[reg.name] = np.array(qubits[base : base + reg.total_bits()]).reshape( - reg.shape - ) - base += reg.total_bits() - return qubit_regs - - def merge_qubits( - self, **qubit_regs: Union[cirq.Qid, Sequence[cirq.Qid], NDArray[cirq.Qid]] - ) -> List[cirq.Qid]: - ret: List[cirq.Qid] = [] - for reg in self: - assert ( - reg.name in qubit_regs - ), f"All qubit registers must be present. {reg.name} not in qubit_regs" - qubits = qubit_regs[reg.name] - qubits = np.array([qubits] if isinstance(qubits, cirq.Qid) else qubits) - assert ( - qubits.shape == reg.shape - ), f'{reg.name} register must of shape {reg.shape} but is of shape {qubits.shape}' - ret += qubits.flatten().tolist() - return ret - - def get_named_qubits(self) -> Dict[str, NDArray[cirq.Qid]]: - def _qubit_array(reg: Register): - qubits = np.empty(reg.shape, dtype=object) - for ii in reg.all_idxs(): - qubits[ii] = cirq.NamedQubit(f'{reg.name}[{", ".join(str(i) for i in ii)}]') - return qubits - - def _qubits_for_reg(reg: Register): - if len(reg.shape) > 1: - return _qubit_array(reg) - - return np.array( - [cirq.NamedQubit(f"{reg.name}")] - if reg.total_bits() == 1 - else cirq.NamedQubit.range(reg.total_bits(), prefix=reg.name), - dtype=object, - ) - - return {reg.name: _qubits_for_reg(reg) for reg in self._registers} - def __eq__(self, other) -> bool: return self._registers == other._registers @@ -166,113 +177,65 @@ class SelectionRegister(Register): `SelectionRegister` extends the `Register` class to store the iteration length corresponding to that register along with its size. - """ - - iteration_length: int = attr.field() - - @iteration_length.default - def _default_iteration_length(self): - return 2 ** self.shape[0] - - @iteration_length.validator - def validate_iteration_length(self, attribute, value): - if len(self.shape) != 1: - raise ValueError(f'Selection register {self.name} should be flat. Found {self.shape=}') - if not (0 <= value <= 2 ** self.shape[0]): - raise ValueError(f'iteration length must be in range [0, 2^{self.shape[0]}]') - - def __repr__(self) -> str: - return ( - f'cirq_ft.SelectionRegister(' - f'name="{self.name}", ' - f'shape={self.shape}, ' - f'iteration_length={self.iteration_length})' - ) - - -class SelectionRegisters(Registers): - """Registers used to represent SELECT registers for various LCU methods. LCU methods often make use of coherent for-loops via UnaryIteration, iterating over a range - of values stored as a superposition over the `SELECT` register. The `SelectionRegisters` class - is used to represent such SELECT registers. In particular, it provides two additional features - on top of the regular `Registers` class: - - - For each selection register, we store the iteration length corresponding to that register - along with its size. - - We provide a default way of "flattening out" a composite index represented by a tuple of - values stored in multiple input selection registers to a single integer that can be used - to index a flat target register. - """ - - def __init__(self, registers: Iterable[SelectionRegister]): - super().__init__(registers) - self.iteration_lengths = tuple([reg.iteration_length for reg in registers]) - self._suffix_prod = np.multiply.accumulate(self.iteration_lengths[::-1])[::-1] - self._suffix_prod = np.append(self._suffix_prod, [1]) + of values stored as a superposition over the `SELECT` register. Such (nested) coherent + for-loops can be represented using a `Tuple[SelectionRegister, ...]` where the i'th entry + stores the bitsize and iteration length of i'th nested for-loop. - def to_flat_idx(self, *selection_vals: int) -> int: - """Flattens a composite index represented by a Tuple[int, ...] to a single output integer. - - For example: + One useful feature when processing such nested for-loops is to flatten out a composite index, + represented by a tuple of indices (i, j, ...), one for each selection register into a single + integer that can be used to index a flat target register. An example of such a mapping + function is described in Eq.45 of https://arxiv.org/abs/1805.03662. A general version of this + mapping function can be implemented using `numpy.ravel_multi_index` and `numpy.unravel_index`. + For example: 1) We can flatten a 2D for-loop as follows + >>> import numpy as np >>> N, M = 10, 20 >>> flat_indices = set() >>> for x in range(N): ... for y in range(M): ... flat_idx = x * M + y + ... assert np.ravel_multi_index((x, y), (N, M)) == flat_idx + ... assert np.unravel_index(flat_idx, (N, M)) == (x, y) ... flat_indices.add(flat_idx) >>> assert len(flat_indices) == N * M 2) Similarly, we can flatten a 3D for-loop as follows + >>> import numpy as np >>> N, M, L = 10, 20, 30 >>> flat_indices = set() >>> for x in range(N): ... for y in range(M): ... for z in range(L): ... flat_idx = x * M * L + y * L + z + ... assert np.ravel_multi_index((x, y, z), (N, M, L)) == flat_idx + ... assert np.unravel_index(flat_idx, (N, M, L)) == (x, y, z) ... flat_indices.add(flat_idx) >>> assert len(flat_indices) == N * M * L + """ - This is a general version of the mapping function described in Eq.45 of - https://arxiv.org/abs/1805.03662 - """ - assert len(selection_vals) == len(self) - return sum(v * self._suffix_prod[i + 1] for i, v in enumerate(selection_vals)) - - @property - def total_iteration_size(self) -> int: - return int(np.product(self.iteration_lengths)) - - @classmethod - def build(cls, **registers: Union[int, Tuple[int, ...]]) -> 'SelectionRegisters': - return cls(SelectionRegister(name=k, shape=v) for k, v in registers.items()) - - @overload - def __getitem__(self, key: int) -> SelectionRegister: - pass - - @overload - def __getitem__(self, key: str) -> SelectionRegister: - pass + iteration_length: int = attr.field() - @overload - def __getitem__(self, key: slice) -> 'SelectionRegisters': - pass + @iteration_length.default + def _default_iteration_length(self): + return 2 ** self.shape[0] - def __getitem__(self, key): - if isinstance(key, slice): - return SelectionRegisters(self._registers[key]) - elif isinstance(key, int): - return self._registers[key] - elif isinstance(key, str): - return self._register_dict[key] - else: - raise IndexError(f"key {key} must be of the type str/int/slice.") + @iteration_length.validator + def validate_iteration_length(self, attribute, value): + if len(self.shape) != 1: + raise ValueError(f'Selection register {self.name} should be flat. Found {self.shape=}') + if not (0 <= value <= 2 ** self.shape[0]): + raise ValueError(f'iteration length must be in range [0, 2^{self.shape[0]}]') def __repr__(self) -> str: - return f'cirq_ft.SelectionRegisters({self._registers})' + return ( + f'cirq_ft.SelectionRegister(' + f'name="{self.name}", ' + f'shape={self.shape}, ' + f'iteration_length={self.iteration_length})' + ) class GateWithRegisters(cirq.Gate, metaclass=abc.ABCMeta): @@ -329,7 +292,7 @@ def registers(self) -> Registers: ... def _num_qubits_(self) -> int: - return self.registers.total_bits() + return total_bits(self.registers) def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] @@ -339,7 +302,7 @@ def decompose_from_registers( def _decompose_with_context_( self, qubits: Sequence[cirq.Qid], context: Optional[cirq.DecompositionContext] = None ) -> cirq.OP_TREE: - qubit_regs = self.registers.split_qubits(qubits) + qubit_regs = split_qubits(self.registers, qubits) if context is None: context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager()) return self.decompose_from_registers(context=context, **qubit_regs) @@ -350,7 +313,7 @@ def _decompose_(self, qubits: Sequence[cirq.Qid]) -> cirq.OP_TREE: def on_registers( self, **qubit_regs: Union[cirq.Qid, Sequence[cirq.Qid], NDArray[cirq.Qid]] ) -> cirq.Operation: - return self.on(*self.registers.merge_qubits(**qubit_regs)) + return self.on(*merge_qubits(self.registers, **qubit_regs)) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: """Default diagram info that uses register names to name the boxes in multi-qubit gates. diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers_test.py b/cirq-ft/cirq_ft/infra/gate_with_registers_test.py index a3442eb0554..7560cb7a357 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers_test.py +++ b/cirq-ft/cirq_ft/infra/gate_with_registers_test.py @@ -17,6 +17,7 @@ import numpy as np import pytest from cirq_ft.infra.jupyter_tools import execute_notebook +from cirq_ft.infra import split_qubits, merge_qubits, get_named_qubits def test_register(): @@ -50,13 +51,20 @@ def test_registers(): assert list(regs) == [r1, r2, r3] qubits = cirq.LineQubit.range(8) - qregs = regs.split_qubits(qubits) + qregs = split_qubits(regs, qubits) assert qregs["r1"].tolist() == cirq.LineQubit.range(5) assert qregs["r2"].tolist() == cirq.LineQubit.range(5, 5 + 2) assert qregs["r3"].tolist() == [cirq.LineQubit(7)] qubits = qubits[::-1] - merged_qregs = regs.merge_qubits(r1=qubits[:5], r2=qubits[5:7], r3=qubits[-1]) + + with pytest.raises(ValueError, match="qubit registers must be present"): + _ = merge_qubits(regs, r1=qubits[:5], r2=qubits[5:7], r4=qubits[-1]) + + with pytest.raises(ValueError, match="register must of shape"): + _ = merge_qubits(regs, r1=qubits[:4], r2=qubits[5:7], r3=qubits[-1]) + + merged_qregs = merge_qubits(regs, r1=qubits[:5], r2=qubits[5:7], r3=qubits[-1]) assert merged_qregs == qubits expected_named_qubits = { @@ -65,7 +73,7 @@ def test_registers(): "r3": [cirq.NamedQubit("r3")], } - named_qregs = regs.get_named_qubits() + named_qregs = get_named_qubits(regs) for reg_name in expected_named_qubits: assert np.array_equal(named_qregs[reg_name], expected_named_qubits[reg_name]) @@ -73,7 +81,7 @@ def test_registers(): # initial registers. for reg_order in [[r1, r2, r3], [r2, r3, r1]]: flat_named_qubits = [ - q for v in cirq_ft.Registers(reg_order).get_named_qubits().values() for q in v + q for v in get_named_qubits(cirq_ft.Registers(reg_order)).values() for q in v ] expected_qubits = [q for r in reg_order for q in expected_named_qubits[r.name]] assert flat_named_qubits == expected_qubits @@ -81,15 +89,13 @@ def test_registers(): @pytest.mark.parametrize('n, N, m, M', [(4, 10, 5, 19), (4, 16, 5, 32)]) def test_selection_registers_indexing(n, N, m, M): - reg = cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('x', n, N), cirq_ft.SelectionRegister('y', m, M)] - ) - assert reg.iteration_lengths == (N, M) - for x in range(N): - for y in range(M): - assert reg.to_flat_idx(x, y) == x * M + y + regs = [cirq_ft.SelectionRegister('x', n, N), cirq_ft.SelectionRegister('y', m, M)] + for x in range(regs[0].iteration_length): + for y in range(regs[1].iteration_length): + assert np.ravel_multi_index((x, y), (N, M)) == x * M + y + assert np.unravel_index(x * M + y, (N, M)) == (x, y) - assert reg.total_iteration_size == N * M + assert np.product(tuple(reg.iteration_length for reg in regs)) == N * M def test_selection_registers_consistent(): @@ -99,7 +105,7 @@ def test_selection_registers_consistent(): with pytest.raises(ValueError, match="should be flat"): _ = cirq_ft.SelectionRegister('a', (3, 5), 5) - selection_reg = cirq_ft.SelectionRegisters( + selection_reg = cirq_ft.Registers( [ cirq_ft.SelectionRegister('n', shape=3, iteration_length=5), cirq_ft.SelectionRegister('m', shape=4, iteration_length=12), @@ -108,7 +114,7 @@ def test_selection_registers_consistent(): assert selection_reg[0] == cirq_ft.SelectionRegister('n', 3, 5) assert selection_reg['n'] == cirq_ft.SelectionRegister('n', 3, 5) assert selection_reg[1] == cirq_ft.SelectionRegister('m', 4, 12) - assert selection_reg[:1] == cirq_ft.SelectionRegisters([cirq_ft.SelectionRegister('n', 3, 5)]) + assert selection_reg[:1] == cirq_ft.Registers([cirq_ft.SelectionRegister('n', 3, 5)]) def test_registers_getitem_raises(): @@ -116,9 +122,7 @@ def test_registers_getitem_raises(): with pytest.raises(IndexError, match="must be of the type"): _ = g[2.5] - selection_reg = cirq_ft.SelectionRegisters( - [cirq_ft.SelectionRegister('n', shape=3, iteration_length=5)] - ) + selection_reg = cirq_ft.Registers([cirq_ft.SelectionRegister('n', shape=3, iteration_length=5)]) with pytest.raises(IndexError, match='must be of the type'): _ = selection_reg[2.5] diff --git a/cirq-ft/cirq_ft/infra/jupyter_tools.py b/cirq-ft/cirq_ft/infra/jupyter_tools.py index 148c29baca8..a9ae4817ef7 100644 --- a/cirq-ft/cirq_ft/infra/jupyter_tools.py +++ b/cirq-ft/cirq_ft/infra/jupyter_tools.py @@ -21,7 +21,7 @@ import IPython.display import ipywidgets import nbformat -from cirq_ft.infra import gate_with_registers, t_complexity_protocol +from cirq_ft.infra import gate_with_registers, t_complexity_protocol, get_named_qubits, merge_qubits from nbconvert.preprocessors import ExecutePreprocessor @@ -83,7 +83,7 @@ def svg_circuit( if registers is not None: qubit_order = cirq.QubitOrder.explicit( - registers.merge_qubits(**registers.get_named_qubits()), fallback=cirq.QubitOrder.DEFAULT + merge_qubits(registers, **get_named_qubits(registers)), fallback=cirq.QubitOrder.DEFAULT ) else: qubit_order = cirq.QubitOrder.DEFAULT diff --git a/cirq-ft/cirq_ft/infra/t_complexity.ipynb b/cirq-ft/cirq_ft/infra/t_complexity.ipynb index 3986abfaa5a..3a4c7c4596a 100644 --- a/cirq-ft/cirq_ft/infra/t_complexity.ipynb +++ b/cirq-ft/cirq_ft/infra/t_complexity.ipynb @@ -40,7 +40,7 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", - "from cirq_ft import And, t_complexity" + "from cirq_ft import And, t_complexity, infra" ] }, { @@ -61,7 +61,7 @@ "# And of two qubits\n", "gate = And() # create an And gate\n", "# create an operation\n", - "operation = gate.on_registers(**gate.registers.get_named_qubits()) \n", + "operation = gate.on_registers(**infra.get_named_qubits(gate.registers))\n", "# this operation doesn't directly support TComplexity but it's decomposable and its components are simple.\n", "print(t_complexity(operation))" ] @@ -82,7 +82,7 @@ "outputs": [], "source": [ "gate = And() ** -1 # adjoint of And\n", - "operation = gate.on_registers(**gate.registers.get_named_qubits()) \n", + "operation = gate.on_registers(**infra.get_named_qubits(gate.registers))\n", "# the deomposition is H, measure, CZ, and Reset\n", "print(t_complexity(operation))" ] @@ -104,7 +104,7 @@ "source": [ "n = 5\n", "gate = And((1, )*n)\n", - "operation = gate.on_registers(**gate.registers.get_named_qubits()) \n", + "operation = gate.on_registers(**infra.get_named_qubits(gate.registers))\n", "print(t_complexity(operation))" ] }, @@ -122,7 +122,7 @@ " for n in range(2, n_max + 2):\n", " n_controls.append(n)\n", " gate = And(cv=(1, )*n)\n", - " op = gate.on_registers(**gate.registers.get_named_qubits()) \n", + " op = gate.on_registers(**infra.get_named_qubits(gate.registers))\n", " c = t_complexity(op)\n", " t_count.append(c.t)\n", " return n_controls, t_count" @@ -171,4 +171,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/cirq-ft/cirq_ft/infra/t_complexity_protocol_test.py b/cirq-ft/cirq_ft/infra/t_complexity_protocol_test.py index 851e5907119..f28f3fc5e6a 100644 --- a/cirq-ft/cirq_ft/infra/t_complexity_protocol_test.py +++ b/cirq-ft/cirq_ft/infra/t_complexity_protocol_test.py @@ -15,6 +15,7 @@ import cirq import cirq_ft import pytest +from cirq_ft import infra from cirq_ft.infra.jupyter_tools import execute_notebook @@ -108,11 +109,11 @@ def test_operations(): assert cirq_ft.t_complexity(cirq.T(q)) == cirq_ft.TComplexity(t=1) gate = cirq_ft.And() - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) assert cirq_ft.t_complexity(op) == cirq_ft.TComplexity(t=4, clifford=9) gate = cirq_ft.And() ** -1 - op = gate.on_registers(**gate.registers.get_named_qubits()) + op = gate.on_registers(**infra.get_named_qubits(gate.registers)) assert cirq_ft.t_complexity(op) == cirq_ft.TComplexity(clifford=4) diff --git a/cirq-ft/cirq_ft/infra/testing.py b/cirq-ft/cirq_ft/infra/testing.py index 6ceb21d5c2a..31802d5300d 100644 --- a/cirq-ft/cirq_ft/infra/testing.py +++ b/cirq-ft/cirq_ft/infra/testing.py @@ -18,7 +18,7 @@ import cirq import numpy as np from cirq._compat import cached_property -from cirq_ft.infra import gate_with_registers, t_complexity_protocol +from cirq_ft.infra import gate_with_registers, t_complexity_protocol, merge_qubits, get_named_qubits from cirq_ft.infra.decompose_protocol import _decompose_once_considering_known_decomposition @@ -44,12 +44,12 @@ def r(self) -> gate_with_registers.Registers: @cached_property def quregs(self) -> Dict[str, NDArray[cirq.Qid]]: # type: ignore[type-var] """A dictionary of named qubits appropriate for the registers for the gate.""" - return self.r.get_named_qubits() + return get_named_qubits(self.r) @cached_property def all_qubits(self) -> List[cirq.Qid]: """All qubits in Register order.""" - merged_qubits = self.r.merge_qubits(**self.quregs) + merged_qubits = merge_qubits(self.r, **self.quregs) decomposed_qubits = self.decomposed_circuit.all_qubits() return merged_qubits + sorted(decomposed_qubits - frozenset(merged_qubits)) diff --git a/cirq-ft/cirq_ft/infra/type_convertors.py b/cirq-ft/cirq_ft/infra/type_convertors.py deleted file mode 100644 index fa182596df7..00000000000 --- a/cirq-ft/cirq_ft/infra/type_convertors.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2023 The Cirq Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Sequence, Tuple, Union - - -def to_tuple(x: Union[int, Sequence[int]]) -> Tuple[int, ...]: - """Mypy type-safe convertor to be used in an attrs field.""" - return (x,) if isinstance(x, int) else tuple(x) diff --git a/cirq-ft/cirq_ft/infra/type_convertors_test.py b/cirq-ft/cirq_ft/infra/type_convertors_test.py deleted file mode 100644 index ce1f780e6af..00000000000 --- a/cirq-ft/cirq_ft/infra/type_convertors_test.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2023 The Cirq Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import cirq_ft - - -def test_to_tuple(): - assert cirq_ft.infra.to_tuple([1, 2]) == (1, 2) - assert cirq_ft.infra.to_tuple((1, 2)) == (1, 2) - assert cirq_ft.infra.to_tuple(1) == (1,) From deedb45377d610a9df6ea67207413a4f1ea57d31 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Fri, 8 Sep 2023 22:18:32 +0000 Subject: [PATCH 04/19] Changed FakeQuantumRunStream to support arbitrary response and exception timing (#6253) * Changed FakeQuantumRunStream to support arbitrary response and exception timing * Fixed type errors and task got bad yield errors * Fixed more type errors and replaced anext() calls * Addressed maffoo's comments --- .../cirq_google/engine/stream_manager_test.py | 534 +++++++++--------- 1 file changed, 282 insertions(+), 252 deletions(-) diff --git a/cirq-google/cirq_google/engine/stream_manager_test.py b/cirq-google/cirq_google/engine/stream_manager_test.py index 7b56dcb8bb3..3732547cdca 100644 --- a/cirq-google/cirq_google/engine/stream_manager_test.py +++ b/cirq-google/cirq_google/engine/stream_manager_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import AsyncIterable, AsyncIterator, Awaitable, List, Union +from typing import AsyncIterable, AsyncIterator, Awaitable, List, Sequence, Union import asyncio import concurrent from unittest import mock @@ -21,6 +21,7 @@ import pytest import google.api_core.exceptions as google_exceptions +from cirq_google.engine.asyncio_executor import AsyncioExecutor from cirq_google.engine.stream_manager import ( _get_retry_request_or_raise, ProgramAlreadyExistsError, @@ -49,63 +50,114 @@ # StreamManager test suite constants REQUEST_PROJECT_NAME = 'projects/proj' REQUEST_PROGRAM = quantum.QuantumProgram(name='projects/proj/programs/prog') -REQUEST_JOB = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job0') +REQUEST_JOB0 = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job0') +REQUEST_JOB1 = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job1') -def setup_fake_quantum_run_stream_client(client_constructor, responses_and_exceptions): - grpc_client = FakeQuantumRunStream(responses_and_exceptions) - client_constructor.return_value = grpc_client - return grpc_client +def setup_client(client_constructor): + fake_client = FakeQuantumRunStream() + client_constructor.return_value = fake_client + return fake_client + + +def setup(client_constructor): + fake_client = setup_client(client_constructor) + return fake_client, StreamManager(fake_client) class FakeQuantumRunStream: """A fake Quantum Engine client which supports QuantumRunStream and CancelQuantumJob.""" - def __init__( - self, responses_and_exceptions: List[Union[quantum.QuantumRunStreamResponse, BaseException]] - ): - self.stream_requests: List[quantum.QuantumRunStreamRequest] = [] - self.cancel_requests: List[quantum.CancelQuantumJobRequest] = [] - self.responses_and_exceptions = responses_and_exceptions + def __init__(self): + self.all_stream_requests: List[quantum.QuantumRunStreamRequest] = [] + self.all_cancel_requests: List[quantum.CancelQuantumJobRequest] = [] + self._executor = AsyncioExecutor.instance() + self._request_buffer = duet.AsyncCollector[quantum.QuantumRunStreamRequest]() + # asyncio.Queue needs to be initialized inside the asyncio thread because all callers need + # to use the same event loop. + self._responses_and_exceptions_future = duet.AwaitableFuture[asyncio.Queue]() async def quantum_run_stream( self, requests: AsyncIterator[quantum.QuantumRunStreamRequest], **kwargs ) -> Awaitable[AsyncIterable[quantum.QuantumRunStreamResponse]]: """Fakes the QuantumRunStream RPC. - Expects the number of requests to be the same as len(self.responses_and_exceptions). - - For every request, a response or exception is popped from `self.responses_and_exceptions`. - Before the next request: - * If it is a response, it is sent back through the stream. - * If it is an exception, the exception is raised. + Once a request is received, it is appended to `stream_requests`, and the test calling + `wait_for_requests()` is notified. - This fake does not support out-of-order responses. + The response is sent when a test calls `reply()` with a `QuantumRunStreamResponse`. If a + test calls `reply()` with an exception, it is raised here to the `quantum_run_stream()` + caller. - No responses are ever made if `self.responses_and_exceptions` is empty. + This is called from the asyncio thread. """ + responses_and_exceptions: asyncio.Queue = asyncio.Queue() + self._responses_and_exceptions_future.try_set_result(responses_and_exceptions) - async def run_async_iterator(): + async def read_requests(): async for request in requests: - self.stream_requests.append(request) - - if not self.responses_and_exceptions: - while True: - await asyncio.sleep(1) - - response_or_exception = self.responses_and_exceptions.pop(0) - if isinstance(response_or_exception, BaseException): + self.all_stream_requests.append(request) + self._request_buffer.add(request) + + async def response_iterator(): + asyncio.create_task(read_requests()) + while True: + response_or_exception = await responses_and_exceptions.get() + if isinstance(response_or_exception, quantum.QuantumRunStreamResponse): + yield response_or_exception + else: # isinstance(response_or_exception, BaseException) + self._responses_and_exceptions_future = duet.AwaitableFuture[asyncio.Queue]() raise response_or_exception - response_or_exception.message_id = request.message_id - yield response_or_exception await asyncio.sleep(0) - return run_async_iterator() + return response_iterator() async def cancel_quantum_job(self, request: quantum.CancelQuantumJobRequest) -> None: - self.cancel_requests.append(request) + """Records the cancellation in `cancel_requests`. + + This is called from the asyncio thread. + """ + self.all_cancel_requests.append(request) await asyncio.sleep(0) + async def wait_for_requests(self, num_requests=1) -> Sequence[quantum.QuantumRunStreamRequest]: + """Wait til `num_requests` number of requests are received via `quantum_run_stream()`. + + This must be called from the duet thread. + + Returns: + The received requests. + """ + requests = [] + for _ in range(num_requests): + requests.append(await self._request_buffer.__anext__()) + return requests + + async def reply( + self, response_or_exception: Union[quantum.QuantumRunStreamResponse, BaseException] + ): + """Sends a response or raises an exception to the `quantum_run_stream()` caller. + + If input response is missing `message_id`, it is defaulted to the `message_id` of the most + recent request. This is to support the most common use case of responding immediately after + a request. + + Assumes that at least one request must have been submitted to the StreamManager. + + This must be called from the duet thread. + """ + responses_and_exceptions = await self._responses_and_exceptions_future + if ( + isinstance(response_or_exception, quantum.QuantumRunStreamResponse) + and not response_or_exception.message_id + ): + response_or_exception.message_id = self.all_stream_requests[-1].message_id + + async def send(): + await responses_and_exceptions.put(response_or_exception) + + await self._executor.submit(send) + class TestResponseDemux: @pytest.fixture @@ -207,49 +259,38 @@ async def test_publish_exception_after_publishing_response_does_not_change_futur class TestStreamManager: @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_submit_expects_result_response(self, client_constructor): + # Arrange + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - # Arrange - expected_result = quantum.QuantumResult( - parent='projects/proj/programs/prog/jobs/job0' - ) - mock_responses = [quantum.QuantumRunStreamResponse(result=expected_result)] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses - ) - manager = StreamManager(fake_client) - # Act - actual_result = await manager.submit( - REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) + await fake_client.wait_for_requests() + await fake_client.reply(quantum.QuantumRunStreamResponse(result=expected_result)) + actual_result = await actual_result_future manager.stop() # Assert assert actual_result == expected_result - assert len(fake_client.stream_requests) == 1 + assert len(fake_client.all_stream_requests) == 1 # assert that the first request is a CreateQuantumProgramAndJobRequest. - assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] duet.run(test) @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_submit_program_without_name_raises(self, client_constructor): + _, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - # Arrange - expected_result = quantum.QuantumResult( - parent='projects/proj/programs/prog/jobs/job0' - ) - mock_responses = [quantum.QuantumRunStreamResponse(result=expected_result)] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses - ) - manager = StreamManager(fake_client) - with pytest.raises(ValueError, match='Program name must be set'): await manager.submit( - REQUEST_PROJECT_NAME, quantum.QuantumProgram(), REQUEST_JOB + REQUEST_PROJECT_NAME, quantum.QuantumProgram(), REQUEST_JOB0 ) manager.stop() @@ -257,20 +298,17 @@ async def test(): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_submit_cancel_future_expects_engine_cancellation_rpc_call(self, client_constructor): + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=[] - ) - manager = StreamManager(fake_client) - - result_future = manager.submit(REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB) + result_future = manager.submit(REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0) result_future.cancel() await duet.sleep(1) # Let cancellation complete asynchronously manager.stop() - assert len(fake_client.cancel_requests) == 1 - assert fake_client.cancel_requests[0] == quantum.CancelQuantumJobRequest( + assert len(fake_client.all_cancel_requests) == 1 + assert fake_client.all_cancel_requests[0] == quantum.CancelQuantumJobRequest( name='projects/proj/programs/prog/jobs/job0' ) @@ -280,31 +318,28 @@ async def test(): def test_submit_stream_broken_twice_expects_retry_with_get_quantum_result_twice( self, client_constructor ): + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - expected_result = quantum.QuantumResult( - parent='projects/proj/programs/prog/jobs/job0' - ) - mock_responses_and_exceptions = [ - google_exceptions.ServiceUnavailable('unavailable'), - google_exceptions.ServiceUnavailable('unavailable'), - quantum.QuantumRunStreamResponse(result=expected_result), - ] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses_and_exceptions - ) - manager = StreamManager(fake_client) - - actual_result = await manager.submit( - REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) + await fake_client.wait_for_requests() + await fake_client.reply(google_exceptions.ServiceUnavailable('unavailable')) + await fake_client.wait_for_requests() + await fake_client.reply(google_exceptions.ServiceUnavailable('unavailable')) + await fake_client.wait_for_requests() + await fake_client.reply(quantum.QuantumRunStreamResponse(result=expected_result)) + actual_result = await actual_result_future manager.stop() assert actual_result == expected_result - assert len(fake_client.stream_requests) == 3 - assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] - assert 'get_quantum_result' in fake_client.stream_requests[1] - assert 'get_quantum_result' in fake_client.stream_requests[2] + assert len(fake_client.all_stream_requests) == 3 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] + assert 'get_quantum_result' in fake_client.all_stream_requests[1] + assert 'get_quantum_result' in fake_client.all_stream_requests[2] duet.run(test) @@ -319,25 +354,24 @@ async def test(): def test_submit_with_retryable_stream_breakage_expects_get_result_request( self, client_constructor, error ): + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - mock_responses = [ - error, - quantum.QuantumRunStreamResponse( - result=quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') - ), - ] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) - manager = StreamManager(fake_client) - - await manager.submit(REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB) + await fake_client.wait_for_requests() + await fake_client.reply(error) + await fake_client.wait_for_requests() + await fake_client.reply(quantum.QuantumRunStreamResponse(result=expected_result)) + await actual_result_future manager.stop() - assert len(fake_client.stream_requests) == 2 - assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] - assert 'get_quantum_result' in fake_client.stream_requests[1] + assert len(fake_client.all_stream_requests) == 2 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] + assert 'get_quantum_result' in fake_client.all_stream_requests[1] duet.run(test) @@ -360,80 +394,73 @@ async def test(): def test_submit_with_non_retryable_stream_breakage_raises_error( self, client_constructor, error ): + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - mock_responses = [ - error, - quantum.QuantumRunStreamResponse( - result=quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') - ), - ] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) - manager = StreamManager(fake_client) - + await fake_client.wait_for_requests() + await fake_client.reply(error) with pytest.raises(type(error)): - await manager.submit(REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB) + await actual_result_future manager.stop() - assert len(fake_client.stream_requests) == 1 - assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] + assert len(fake_client.all_stream_requests) == 1 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] duet.run(test) @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_submit_expects_job_response(self, client_constructor): + expected_job = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - expected_job = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job0') - mock_responses = [quantum.QuantumRunStreamResponse(job=expected_job)] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses - ) - manager = StreamManager(fake_client) - - actual_job = await manager.submit( - REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB + actual_job_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) + await fake_client.wait_for_requests() + await fake_client.reply(quantum.QuantumRunStreamResponse(job=expected_job)) + actual_job = await actual_job_future manager.stop() assert actual_job == expected_job - assert len(fake_client.stream_requests) == 1 - # assert that the first request is a CreateQuantumProgramAndJobRequest. - assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] + assert len(fake_client.all_stream_requests) == 1 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] duet.run(test) @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_submit_job_does_not_exist_expects_create_quantum_job_request(self, client_constructor): + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - expected_result = quantum.QuantumResult( - parent='projects/proj/programs/prog/jobs/job0' + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) - mock_responses_and_exceptions = [ - google_exceptions.ServiceUnavailable('unavailable'), + await fake_client.wait_for_requests() + await fake_client.reply(google_exceptions.ServiceUnavailable('unavailable')) + await fake_client.wait_for_requests() + await fake_client.reply( quantum.QuantumRunStreamResponse( error=quantum.StreamError(code=quantum.StreamError.Code.JOB_DOES_NOT_EXIST) - ), - quantum.QuantumRunStreamResponse(result=expected_result), - ] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses_and_exceptions - ) - manager = StreamManager(fake_client) - - actual_result = await manager.submit( - REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB + ) ) + await fake_client.wait_for_requests() + await fake_client.reply(quantum.QuantumRunStreamResponse(result=expected_result)) + actual_result = await actual_result_future manager.stop() assert actual_result == expected_result - assert len(fake_client.stream_requests) == 3 - assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] - assert 'get_quantum_result' in fake_client.stream_requests[1] - assert 'create_quantum_job' in fake_client.stream_requests[2] + assert len(fake_client.all_stream_requests) == 3 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] + assert 'get_quantum_result' in fake_client.all_stream_requests[1] + assert 'create_quantum_job' in fake_client.all_stream_requests[2] duet.run(test) @@ -441,39 +468,41 @@ async def test(): def test_submit_program_does_not_exist_expects_create_quantum_program_and_job_request( self, client_constructor ): + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - expected_result = quantum.QuantumResult( - parent='projects/proj/programs/prog/jobs/job0' + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) - mock_responses_and_exceptions = [ - google_exceptions.ServiceUnavailable('unavailable'), + await fake_client.wait_for_requests() + await fake_client.reply(google_exceptions.ServiceUnavailable('unavailable')) + await fake_client.wait_for_requests() + await fake_client.reply( quantum.QuantumRunStreamResponse( error=quantum.StreamError(code=quantum.StreamError.Code.JOB_DOES_NOT_EXIST) - ), + ) + ) + await fake_client.wait_for_requests() + await fake_client.reply( quantum.QuantumRunStreamResponse( error=quantum.StreamError( code=quantum.StreamError.Code.PROGRAM_DOES_NOT_EXIST ) - ), - quantum.QuantumRunStreamResponse(result=expected_result), - ] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses_and_exceptions - ) - manager = StreamManager(fake_client) - - actual_result = await manager.submit( - REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB + ) ) + await fake_client.wait_for_requests() + await fake_client.reply(quantum.QuantumRunStreamResponse(result=expected_result)) + actual_result = await actual_result_future manager.stop() assert actual_result == expected_result - assert len(fake_client.stream_requests) == 4 - assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] - assert 'get_quantum_result' in fake_client.stream_requests[1] - assert 'create_quantum_job' in fake_client.stream_requests[2] - assert 'create_quantum_program_and_job' in fake_client.stream_requests[3] + assert len(fake_client.all_stream_requests) == 4 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] + assert 'get_quantum_result' in fake_client.all_stream_requests[1] + assert 'create_quantum_job' in fake_client.all_stream_requests[2] + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[3] duet.run(test) @@ -481,124 +510,129 @@ async def test(): def test_submit_program_already_exists_expects_program_already_exists_error( self, client_constructor ): + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - mock_responses_and_exceptions = [ + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 + ) + await fake_client.wait_for_requests() + await fake_client.reply( quantum.QuantumRunStreamResponse( error=quantum.StreamError( code=quantum.StreamError.Code.PROGRAM_ALREADY_EXISTS ) ) - ] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses_and_exceptions ) - manager = StreamManager(fake_client) - with pytest.raises(ProgramAlreadyExistsError): - await manager.submit(REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB) + await actual_result_future manager.stop() duet.run(test) @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_submit_twice_in_parallel_expect_result_responses(self, client_constructor): + expected_result0 = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + expected_result1 = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job1') + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - request_job1 = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job1') - expected_result0 = quantum.QuantumResult( - parent='projects/proj/programs/prog/jobs/job0' + actual_result0_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) - expected_result1 = quantum.QuantumResult( - parent='projects/proj/programs/prog/jobs/job1' + actual_result1_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB1 ) - mock_responses = [ - quantum.QuantumRunStreamResponse(result=expected_result0), - quantum.QuantumRunStreamResponse(result=expected_result1), - ] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses + await fake_client.wait_for_requests(num_requests=2) + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=fake_client.all_stream_requests[0].message_id, + result=expected_result0, + ) ) - manager = StreamManager(fake_client) + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=fake_client.all_stream_requests[1].message_id, + result=expected_result1, + ) + ) + actual_result1 = await actual_result1_future + actual_result0 = await actual_result0_future + manager.stop() + assert actual_result0 == expected_result0 + assert actual_result1 == expected_result1 + assert len(fake_client.all_stream_requests) == 2 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[1] + + duet.run(test) + + @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) + def test_submit_twice_and_break_stream_expect_result_responses(self, client_constructor): + expected_result0 = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + expected_result1 = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job1') + fake_client, manager = setup(client_constructor) + + async def test(): + async with duet.timeout_scope(5): actual_result0_future = manager.submit( - REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) actual_result1_future = manager.submit( - REQUEST_PROJECT_NAME, REQUEST_PROGRAM, request_job1 + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB1 + ) + await fake_client.wait_for_requests(num_requests=2) + await fake_client.reply(google_exceptions.ServiceUnavailable('unavailable')) + await fake_client.wait_for_requests(num_requests=2) + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=next( + req.message_id + for req in fake_client.all_stream_requests[2:] + if req.get_quantum_result.parent == expected_result0.parent + ), + result=expected_result0, + ) + ) + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=next( + req.message_id + for req in fake_client.all_stream_requests[2:] + if req.get_quantum_result.parent == expected_result1.parent + ), + result=expected_result1, + ) ) - actual_result1 = await actual_result1_future actual_result0 = await actual_result0_future + actual_result1 = await actual_result1_future manager.stop() assert actual_result0 == expected_result0 assert actual_result1 == expected_result1 - assert len(fake_client.stream_requests) == 2 - assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] - assert 'create_quantum_program_and_job' in fake_client.stream_requests[1] + assert len(fake_client.all_stream_requests) == 4 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[1] + assert 'get_quantum_result' in fake_client.all_stream_requests[2] + assert 'get_quantum_result' in fake_client.all_stream_requests[3] duet.run(test) - # TODO(#5996) Update fake client implementation to support this test case. - # @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) - # def test_submit_twice_and_break_stream_expect_result_responses(self, client_constructor): - # async def test(): - # async with duet.timeout_scope(5): - # request_job1 = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job1') - # expected_result0 = quantum.QuantumResult( - # parent='projects/proj/programs/prog/jobs/job0' - # ) - # expected_result1 = quantum.QuantumResult( - # parent='projects/proj/programs/prog/jobs/job1' - # ) - # # TODO the current fake client doesn't have the response timing flexibility - # # required by this test. - # # Ideally, the client raises ServiceUnavailable after both initial requests are - # # sent. - # mock_responses = [ - # google_exceptions.ServiceUnavailable('unavailable'), - # google_exceptions.ServiceUnavailable('unavailable'), - # quantum.QuantumRunStreamResponse(result=expected_result0), - # quantum.QuantumRunStreamResponse(result=expected_result1), - # ] - # fake_client = setup_fake_quantum_run_stream_client( - # client_constructor, responses_and_exceptions=mock_responses - # ) - # manager = StreamManager(fake_client) - - # actual_result0_future = manager.submit( - # REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB - # ) - # actual_result1_future = manager.submit( - # REQUEST_PROJECT_NAME, REQUEST_PROGRAM, request_job1 - # ) - # actual_result1 = await actual_result1_future - # actual_result0 = await actual_result0_future - # manager.stop() - - # assert actual_result0 == expected_result0 - # assert actual_result1 == expected_result1 - # assert len(fake_client.stream_requests) == 2 - # assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] - # assert 'create_quantum_program_and_job' in fake_client.stream_requests[1] - - # duet.run(test) - @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_stop_cancels_existing_sends(self, client_constructor): + fake_client, manager = setup(client_constructor) + async def test(): async with duet.timeout_scope(5): - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=[] - ) - manager = StreamManager(fake_client) - actual_result_future = manager.submit( - REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) # Wait for the manager to submit a request. If request submission runs after stop(), # it will start the manager again and the test will block waiting for a response. - await duet.sleep(1) + await fake_client.wait_for_requests() manager.stop() with pytest.raises(concurrent.futures.CancelledError): @@ -609,28 +643,24 @@ async def test(): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_stop_then_send_expects_result_response(self, client_constructor): """New requests should work after stopping the manager.""" + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) async def test(): async with duet.timeout_scope(5): - expected_result = quantum.QuantumResult( - parent='projects/proj/programs/prog/jobs/job0' - ) - mock_responses = [quantum.QuantumRunStreamResponse(result=expected_result)] - fake_client = setup_fake_quantum_run_stream_client( - client_constructor, responses_and_exceptions=mock_responses - ) - manager = StreamManager(fake_client) - manager.stop() - actual_result = await manager.submit( - REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 ) + await fake_client.wait_for_requests() + await fake_client.reply(quantum.QuantumRunStreamResponse(result=expected_result)) + actual_result = await actual_result_future manager.stop() assert actual_result == expected_result - assert len(fake_client.stream_requests) == 1 + assert len(fake_client.all_stream_requests) == 1 # assert that the first request is a CreateQuantumProgramAndJobRequest. - assert 'create_quantum_program_and_job' in fake_client.stream_requests[0] + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] duet.run(test) From 6c14cfa76942e2f7421898816003d8f9a5fbe9e3 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Mon, 11 Sep 2023 20:26:15 +0000 Subject: [PATCH 05/19] StreamManager: Add mechanism to close the request iterator (#6263) * Add a signal to stop the request iterator * Make request_queue local to asyncio coroutines * Added missing raises docstring * Addressed maffoo's comments * Addressed maffoo's nits * Fix failing stream_manager_test after merging * Fix format --- .../cirq_google/engine/stream_manager.py | 71 ++++++---- .../cirq_google/engine/stream_manager_test.py | 127 ++++++++++++++++-- 2 files changed, 164 insertions(+), 34 deletions(-) diff --git a/cirq-google/cirq_google/engine/stream_manager.py b/cirq-google/cirq_google/engine/stream_manager.py index b5bb5696eda..c45e43d81fc 100644 --- a/cirq-google/cirq_google/engine/stream_manager.py +++ b/cirq-google/cirq_google/engine/stream_manager.py @@ -109,8 +109,6 @@ class StreamManager: def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient): self._grpc_client = grpc_client - # TODO(#5996) Make this local to the asyncio thread. - self._request_queue: Optional[asyncio.Queue] = None # Used to determine whether the stream coroutine is actively running, and provides a way to # cancel it. self._manage_stream_loop_future: Optional[duet.AwaitableFuture[None]] = None @@ -121,6 +119,16 @@ def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient): # interface. self._response_demux = ResponseDemux() self._next_available_message_id = 0 + # Construct queue in AsyncioExecutor to ensure it binds to the correct event loop, since it + # is used by asyncio coroutines. + self._request_queue = self._executor.submit(self._make_request_queue).result() + + async def _make_request_queue(self) -> asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]]: + """Returns a queue used to back the request iterator passed to the stream. + + If `None` is put into the queue, the request iterator will stop. + """ + return asyncio.Queue() def submit( self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob @@ -153,8 +161,12 @@ def submit( raise ValueError('Program name must be set.') if self._manage_stream_loop_future is None or self._manage_stream_loop_future.done(): - self._manage_stream_loop_future = self._executor.submit(self._manage_stream) - return self._executor.submit(self._manage_execution, project_name, program, job) + self._manage_stream_loop_future = self._executor.submit( + self._manage_stream, self._request_queue + ) + return self._executor.submit( + self._manage_execution, self._request_queue, project_name, program, job + ) def stop(self) -> None: """Closes the open stream and resets all management resources.""" @@ -168,9 +180,9 @@ def stop(self) -> None: def _reset(self): """Resets the manager state.""" - self._request_queue = None self._manage_stream_loop_future = None self._response_demux = ResponseDemux() + self._request_queue = self._executor.submit(self._make_request_queue).result() @property def _executor(self) -> AsyncioExecutor: @@ -178,7 +190,9 @@ def _executor(self) -> AsyncioExecutor: # clients: https://github.com/grpc/grpc/issues/25364. return AsyncioExecutor.instance() - async def _manage_stream(self) -> None: + async def _manage_stream( + self, request_queue: asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]] + ) -> None: """The stream coroutine, an asyncio coroutine to manage QuantumRunStream. This coroutine reads responses from the stream and forwards them to the ResponseDemux, where @@ -187,25 +201,32 @@ async def _manage_stream(self) -> None: When the stream breaks, the stream is reopened, and all execution coroutines are notified. There is at most a single instance of this coroutine running. + + Args: + request_queue: The queue holding requests from the execution coroutine. """ - self._request_queue = asyncio.Queue() while True: try: # The default gRPC client timeout is used. response_iterable = await self._grpc_client.quantum_run_stream( - _request_iterator(self._request_queue) + _request_iterator(request_queue) ) async for response in response_iterable: self._response_demux.publish(response) except asyncio.CancelledError: + await request_queue.put(None) break except BaseException as e: - # TODO(#5996) Close the request iterator to close the existing stream. # Note: the message ID counter is not reset upon a new stream. + await request_queue.put(None) self._response_demux.publish_exception(e) # Raise to all request tasks async def _manage_execution( - self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob + self, + request_queue: asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]], + project_name: str, + program: quantum.QuantumProgram, + job: quantum.QuantumJob, ) -> Union[quantum.QuantumResult, quantum.QuantumJob]: """The execution coroutine, an asyncio coroutine to manage the lifecycle of a job execution. @@ -216,8 +237,20 @@ async def _manage_execution( error by sending another request. The exact request type depends on the error. There is one execution coroutine per running job submission. + + Args: + request_queue: The queue used to send requests to the stream coroutine. + project_name: The full project ID resource path associated with the job. + program: The Quantum Engine program representing the circuit to be executed. + job: The Quantum Engine job to be executed. + + Raises: + concurrent.futures.CancelledError: if either the request is cancelled or the stream + coroutine is cancelled. + google.api_core.exceptions.GoogleAPICallError: if the stream breaks with a non-retryable + error. + ValueError: if the response is of a type which is not recognized by this client. """ - # Construct requests ahead of time to be reused for retries. create_program_and_job_request = quantum.QuantumRunStreamRequest( parent=project_name, create_quantum_program_and_job=quantum.CreateQuantumProgramAndJobRequest( @@ -225,19 +258,12 @@ async def _manage_execution( ), ) - while self._request_queue is None: - # Wait for the stream coroutine to start. - # Ignoring coverage since this is rarely triggered. - # TODO(#5996) Consider awaiting for the queue to become available, once it is changed - # to be local to the asyncio thread. - await asyncio.sleep(1) # pragma: no cover - current_request = create_program_and_job_request while True: try: current_request.message_id = self._generate_message_id() response_future = self._response_demux.subscribe(current_request.message_id) - await self._request_queue.put(current_request) + await request_queue.put(current_request) response = await response_future # Broken stream @@ -325,16 +351,15 @@ def _is_retryable_error(e: google_exceptions.GoogleAPICallError) -> bool: return any(isinstance(e, exception_type) for exception_type in RETRYABLE_GOOGLE_API_EXCEPTIONS) -# TODO(#5996) Add stop signal to the request iterator. async def _request_iterator( - request_queue: asyncio.Queue, + request_queue: asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]], ) -> AsyncIterator[quantum.QuantumRunStreamRequest]: """The request iterator for Quantum Engine client RPC quantum_run_stream(). Every call to this method generates a new iterator. """ - while True: - yield await request_queue.get() + while request := await request_queue.get(): + yield request def _to_create_job_request( diff --git a/cirq-google/cirq_google/engine/stream_manager_test.py b/cirq-google/cirq_google/engine/stream_manager_test.py index 3732547cdca..42e6defbcc8 100644 --- a/cirq-google/cirq_google/engine/stream_manager_test.py +++ b/cirq-google/cirq_google/engine/stream_manager_test.py @@ -68,21 +68,26 @@ def setup(client_constructor): class FakeQuantumRunStream: """A fake Quantum Engine client which supports QuantumRunStream and CancelQuantumJob.""" + _REQUEST_STOPPED = 'REQUEST_STOPPED' + def __init__(self): self.all_stream_requests: List[quantum.QuantumRunStreamRequest] = [] self.all_cancel_requests: List[quantum.CancelQuantumJobRequest] = [] self._executor = AsyncioExecutor.instance() self._request_buffer = duet.AsyncCollector[quantum.QuantumRunStreamRequest]() + self._request_iterator_stopped = duet.AwaitableFuture() # asyncio.Queue needs to be initialized inside the asyncio thread because all callers need # to use the same event loop. - self._responses_and_exceptions_future = duet.AwaitableFuture[asyncio.Queue]() + self._responses_and_exceptions_future: duet.AwaitableFuture[ + asyncio.Queue[Union[quantum.QuantumRunStreamResponse, BaseException]] + ] = duet.AwaitableFuture() async def quantum_run_stream( self, requests: AsyncIterator[quantum.QuantumRunStreamRequest], **kwargs ) -> Awaitable[AsyncIterable[quantum.QuantumRunStreamResponse]]: """Fakes the QuantumRunStream RPC. - Once a request is received, it is appended to `stream_requests`, and the test calling + Once a request is received, it is appended to `all_stream_requests`, and the test calling `wait_for_requests()` is notified. The response is sent when a test calls `reply()` with a `QuantumRunStreamResponse`. If a @@ -91,25 +96,29 @@ async def quantum_run_stream( This is called from the asyncio thread. """ - responses_and_exceptions: asyncio.Queue = asyncio.Queue() + responses_and_exceptions: asyncio.Queue[ + Union[quantum.QuantumRunStreamResponse, BaseException] + ] = asyncio.Queue() self._responses_and_exceptions_future.try_set_result(responses_and_exceptions) async def read_requests(): async for request in requests: self.all_stream_requests.append(request) self._request_buffer.add(request) + await responses_and_exceptions.put(FakeQuantumRunStream._REQUEST_STOPPED) + self._request_iterator_stopped.try_set_result(None) async def response_iterator(): asyncio.create_task(read_requests()) - while True: - response_or_exception = await responses_and_exceptions.get() - if isinstance(response_or_exception, quantum.QuantumRunStreamResponse): - yield response_or_exception - else: # isinstance(response_or_exception, BaseException) - self._responses_and_exceptions_future = duet.AwaitableFuture[asyncio.Queue]() - raise response_or_exception + while ( + message := await responses_and_exceptions.get() + ) != FakeQuantumRunStream._REQUEST_STOPPED: + if isinstance(message, quantum.QuantumRunStreamResponse): + yield message + else: # isinstance(message, BaseException) + self._responses_and_exceptions_future = duet.AwaitableFuture() + raise message - await asyncio.sleep(0) return response_iterator() async def cancel_quantum_job(self, request: quantum.CancelQuantumJobRequest) -> None: @@ -158,6 +167,14 @@ async def send(): await self._executor.submit(send) + async def wait_for_request_iterator_stop(self): + """Wait for the request iterator to stop. + + This must be called from a duet thread. + """ + await self._request_iterator_stopped + self._request_iterator_stopped = duet.AwaitableFuture() + class TestResponseDemux: @pytest.fixture @@ -704,3 +721,91 @@ def test_get_retry_request_or_raise_expects_stream_error( create_quantum_program_and_job_request, create_quantum_job_request, ) + + @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) + def test_broken_stream_stops_request_iterator(self, client_constructor): + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) + + async def test(): + async with duet.timeout_scope(5): + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 + ) + await fake_client.wait_for_requests() + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=fake_client.all_stream_requests[0].message_id, + result=expected_result, + ) + ) + await actual_result_future + await fake_client.reply(google_exceptions.ServiceUnavailable('service unavailable')) + await fake_client.wait_for_request_iterator_stop() + manager.stop() + + duet.run(test) + + @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) + def test_stop_stops_request_iterator(self, client_constructor): + expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + fake_client, manager = setup(client_constructor) + + async def test(): + async with duet.timeout_scope(5): + actual_result_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 + ) + await fake_client.wait_for_requests() + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=fake_client.all_stream_requests[0].message_id, + result=expected_result, + ) + ) + await actual_result_future + manager.stop() + await fake_client.wait_for_request_iterator_stop() + + duet.run(test) + + @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) + def test_submit_after_stream_breakage(self, client_constructor): + expected_result0 = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + expected_result1 = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job1') + fake_client, manager = setup(client_constructor) + + async def test(): + async with duet.timeout_scope(5): + actual_result0_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 + ) + await fake_client.wait_for_requests() + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=fake_client.all_stream_requests[0].message_id, + result=expected_result0, + ) + ) + actual_result0 = await actual_result0_future + await fake_client.reply(google_exceptions.ServiceUnavailable('service unavailable')) + actual_result1_future = manager.submit( + REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0 + ) + await fake_client.wait_for_requests() + await fake_client.reply( + quantum.QuantumRunStreamResponse( + message_id=fake_client.all_stream_requests[1].message_id, + result=expected_result1, + ) + ) + actual_result1 = await actual_result1_future + manager.stop() + + assert len(fake_client.all_stream_requests) == 2 + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0] + assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[1] + assert actual_result0 == expected_result0 + assert actual_result1 == expected_result1 + + duet.run(test) From cf005c2d59ba8acc3ffc072aa9cb3015b59915f2 Mon Sep 17 00:00:00 2001 From: Spencer Churchill Date: Tue, 12 Sep 2023 19:36:10 -0700 Subject: [PATCH 06/19] add to ionq code owners (#6273) * add to ionq code owners * pass tests --- .github/CODEOWNERS | 4 ++-- dev_tools/codeowners_test.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 866d69af216..e84e63f19a2 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -10,7 +10,7 @@ cirq-google/**/*.* @wcourtney @quantumlib/cirq-maintainers @vtomole @cduck @verult -cirq-ionq/**/*.* @dabacon @ColemanCollins @nakardo @gmauricio @Cynocracy @quantumlib/cirq-maintainers @vtomole @cduck +cirq-ionq/**/*.* @dabacon @ColemanCollins @nakardo @gmauricio @Cynocracy @quantumlib/cirq-maintainers @vtomole @cduck @splch cirq-aqt/**/*.* @ma5x @pschindler @alfrisch @quantumlib/cirq-maintainers @vtomole @cduck @@ -37,7 +37,7 @@ docs/**/*.* @aasfaw @rmlarose @quantumlib/cirq-maintainers @vtomole @cduck docs/google/**/*.* @wcourtney @aasfaw @rmlarose @quantumlib/cirq-maintainers @vtomole @cduck @verult docs/tutorials/google/**/*.* @wcourtney @aasfaw @rmlarose @quantumlib/cirq-maintainers @vtomole @cduck @verult -docs/hardware/ionq/**/*.* @dabacon @ColemanCollins @nakardo @gmauricio @aasfaw @rmlarose @Cynocracy @quantumlib/cirq-maintainers @vtomole @cduck +docs/hardware/ionq/**/*.* @dabacon @ColemanCollins @nakardo @gmauricio @aasfaw @rmlarose @Cynocracy @quantumlib/cirq-maintainers @vtomole @cduck @splch docs/hardware/aqt/**/*.* @ma5x @pschindler @alfrisch @aasfaw @rmlarose @quantumlib/cirq-maintainers @vtomole @cduck diff --git a/dev_tools/codeowners_test.py b/dev_tools/codeowners_test.py index b0e788371f5..64c6786acb9 100644 --- a/dev_tools/codeowners_test.py +++ b/dev_tools/codeowners_test.py @@ -26,7 +26,8 @@ GOOGLE_MAINTAINERS = BASE_MAINTAINERS.union(GOOGLE_TEAM) IONQ_TEAM = { - ('USERNAME', u) for u in ["@dabacon", "@ColemanCollins", "@nakardo", "@gmauricio", "@Cynocracy"] + ('USERNAME', u) + for u in ["@dabacon", "@ColemanCollins", "@nakardo", "@gmauricio", "@Cynocracy", "@splch"] } IONQ_MAINTAINERS = BASE_MAINTAINERS.union(IONQ_TEAM) From 432d57a8510a9927513225de91ba2b27bf88cefd Mon Sep 17 00:00:00 2001 From: Suyash Damle Date: Thu, 14 Sep 2023 07:42:32 +0000 Subject: [PATCH 07/19] Add serialization support for InsertionNoiseModel (#6282) --- .../cirq/devices/insertion_noise_model.py | 22 ++++- .../devices/insertion_noise_model_test.py | 6 ++ cirq-core/cirq/json_resolver_cache.py | 2 + .../json_test_data/InsertionNoiseModel.json | 91 +++++++++++++++++++ .../json_test_data/InsertionNoiseModel.repr | 4 + 5 files changed, 124 insertions(+), 1 deletion(-) create mode 100644 cirq-core/cirq/protocols/json_test_data/InsertionNoiseModel.json create mode 100644 cirq-core/cirq/protocols/json_test_data/InsertionNoiseModel.repr diff --git a/cirq-core/cirq/devices/insertion_noise_model.py b/cirq-core/cirq/devices/insertion_noise_model.py index ab6604868fc..cbe158ae8a5 100644 --- a/cirq-core/cirq/devices/insertion_noise_model.py +++ b/cirq-core/cirq/devices/insertion_noise_model.py @@ -13,7 +13,7 @@ # limitations under the License. import dataclasses -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence from cirq import devices from cirq.devices import noise_utils @@ -74,3 +74,23 @@ def noisy_moment( if self.prepend: return [*noise_steps.moments, moment] return [moment, *noise_steps.moments] + + def __repr__(self) -> str: + return ( + f'cirq.devices.InsertionNoiseModel(ops_added={self.ops_added},' + + f' prepend={self.prepend},' + + f' require_physical_tag={self.require_physical_tag})' + ) + + def _json_dict_(self) -> Dict[str, Any]: + return { + 'ops_added': list(self.ops_added.items()), + 'prepend': self.prepend, + 'require_physical_tag': self.require_physical_tag, + } + + @classmethod + def _from_json_dict_(cls, ops_added, prepend, require_physical_tag, **kwargs): + return cls( + ops_added=dict(ops_added), prepend=prepend, require_physical_tag=require_physical_tag + ) diff --git a/cirq-core/cirq/devices/insertion_noise_model_test.py b/cirq-core/cirq/devices/insertion_noise_model_test.py index dc15eec2b44..3a316b805e6 100644 --- a/cirq-core/cirq/devices/insertion_noise_model_test.py +++ b/cirq-core/cirq/devices/insertion_noise_model_test.py @@ -47,6 +47,8 @@ def test_insertion_noise(): moment_3 = cirq.Moment(cirq.Z(q0), cirq.X(q1)) assert model.noisy_moment(moment_3, system_qubits=[q0, q1]) == [moment_3] + cirq.testing.assert_equivalent_repr(model) + def test_colliding_noise_qubits(): # Check that noise affecting other qubits doesn't cause issues. @@ -61,6 +63,8 @@ def test_colliding_noise_qubits(): cirq.Moment(cirq.CNOT(q1, q2)), ] + cirq.testing.assert_equivalent_repr(model) + def test_prepend(): q0, q1 = cirq.LineQubit.range(2) @@ -106,3 +110,5 @@ def test_supertype_matching(): moment_1 = cirq.Moment(cirq.Y(q0)) assert model.noisy_moment(moment_1, system_qubits=[q0]) == [moment_1, cirq.Moment(cirq.T(q0))] + + cirq.testing.assert_equivalent_repr(model) diff --git a/cirq-core/cirq/json_resolver_cache.py b/cirq-core/cirq/json_resolver_cache.py index 20a7377b294..f1d178bb530 100644 --- a/cirq-core/cirq/json_resolver_cache.py +++ b/cirq-core/cirq/json_resolver_cache.py @@ -47,6 +47,7 @@ def _class_resolver_dictionary() -> Dict[str, ObjectFactory]: import pandas as pd import numpy as np from cirq.devices.noise_model import _NoNoiseModel + from cirq.devices import InsertionNoiseModel from cirq.experiments import GridInteractionLayer from cirq.experiments.grid_parallel_two_qubit_xeb import GridParallelXEBMetadata @@ -147,6 +148,7 @@ def _symmetricalqidpair(qids): 'ISwapPowGate': cirq.ISwapPowGate, 'IdentityGate': cirq.IdentityGate, 'InitObsSetting': cirq.work.InitObsSetting, + 'InsertionNoiseModel': InsertionNoiseModel, 'KeyCondition': cirq.KeyCondition, 'KrausChannel': cirq.KrausChannel, 'LinearDict': cirq.LinearDict, diff --git a/cirq-core/cirq/protocols/json_test_data/InsertionNoiseModel.json b/cirq-core/cirq/protocols/json_test_data/InsertionNoiseModel.json new file mode 100644 index 00000000000..1a825bdbe48 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/InsertionNoiseModel.json @@ -0,0 +1,91 @@ +[ + { + "cirq_type": "InsertionNoiseModel", + "ops_added": [ + [ + { + "cirq_type": "OpIdentifier", + "gate_type": "XPowGate", + "qubits": [ + { + "cirq_type": "LineQubit", + "x": 0 + } + ] + }, + { + "cirq_type": "GateOperation", + "gate": { + "cirq_type": "BitFlipChannel", + "p": 0.2 + }, + "qubits": [ + { + "cirq_type": "LineQubit", + "x": 0 + } + ] + } + ] + ], + "prepend": false, + "require_physical_tag": false + }, + { + "cirq_type": "InsertionNoiseModel", + "ops_added": [ + [ + { + "cirq_type": "OpIdentifier", + "gate_type": "XPowGate", + "qubits": [ + { + "cirq_type": "LineQubit", + "x": 0 + } + ] + }, + { + "cirq_type": "GateOperation", + "gate": { + "cirq_type": "BitFlipChannel", + "p": 0.2 + }, + "qubits": [ + { + "cirq_type": "LineQubit", + "x": 0 + } + ] + } + ], + [ + { + "cirq_type": "OpIdentifier", + "gate_type": "HPowGate", + "qubits": [ + { + "cirq_type": "LineQubit", + "x": 1 + } + ] + }, + { + "cirq_type": "GateOperation", + "gate": { + "cirq_type": "BitFlipChannel", + "p": 0.1 + }, + "qubits": [ + { + "cirq_type": "LineQubit", + "x": 1 + } + ] + } + ] + ], + "prepend": false, + "require_physical_tag": false + } +] \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/InsertionNoiseModel.repr b/cirq-core/cirq/protocols/json_test_data/InsertionNoiseModel.repr new file mode 100644 index 00000000000..d0650e0d6fc --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/InsertionNoiseModel.repr @@ -0,0 +1,4 @@ +[ +cirq.devices.InsertionNoiseModel(ops_added={cirq.devices.noise_utils.OpIdentifier(cirq.ops.common_gates.XPowGate, cirq.LineQubit(0)): cirq.bit_flip(p=0.2).on(cirq.LineQubit(0))}, prepend=False, require_physical_tag=False), +cirq.devices.InsertionNoiseModel(ops_added={cirq.devices.noise_utils.OpIdentifier(cirq.ops.common_gates.XPowGate, cirq.LineQubit(0)): cirq.bit_flip(p=0.2).on(cirq.LineQubit(0)), cirq.devices.noise_utils.OpIdentifier(cirq.ops.common_gates.HPowGate, cirq.LineQubit(1)): cirq.bit_flip(p=0.1).on(cirq.LineQubit(1))}, prepend=False, require_physical_tag=False) +] \ No newline at end of file From f715527bdf0da4763cd196ce2be59832a530dec1 Mon Sep 17 00:00:00 2001 From: Bram Evert Date: Fri, 15 Sep 2023 01:19:05 +0100 Subject: [PATCH 08/19] Fix documentation of FSimGate (#6288) --- cirq-core/cirq/ops/fsim_gate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/ops/fsim_gate.py b/cirq-core/cirq/ops/fsim_gate.py index b3369411324..01cf6291b8b 100644 --- a/cirq-core/cirq/ops/fsim_gate.py +++ b/cirq-core/cirq/ops/fsim_gate.py @@ -81,7 +81,7 @@ class FSimGate(gate_features.InterchangeableQubitsGate, raw_types.Gate): $$ $$ - c = e^{i \phi} + c = e^{-i \phi} $$ Note the difference in sign conventions between FSimGate and the From b630298df901b1b49d7a996fa0a1104d830174d6 Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Mon, 18 Sep 2023 12:01:15 -0700 Subject: [PATCH 09/19] Fix matplotlib typing (#6290) * Fix matplotlib typing matplotlib 3.8.0 was released this week and included typing hints. This fixes the resulting CI breakages. * Fix issues. * formatting * Change to seaborn v0_8 --- cirq-core/cirq/contrib/svg/svg.py | 4 +++- cirq-core/cirq/devices/named_topologies.py | 2 +- .../cirq/experiments/qubit_characterizations.py | 11 ++++++----- cirq-core/cirq/linalg/decompositions.py | 11 ++++++----- cirq-core/cirq/vis/heatmap.py | 11 +++++++---- cirq-core/cirq/vis/heatmap_test.py | 10 ++++++++++ cirq-core/cirq/vis/histogram.py | 8 ++++---- cirq-core/cirq/vis/state_histogram.py | 16 ++++++++++------ cirq-core/cirq/vis/state_histogram_test.py | 2 ++ cirq-google/cirq_google/engine/calibration.py | 5 +++-- docs/experiments/textbook_algorithms.ipynb | 2 +- docs/start/intro.ipynb | 4 ++-- examples/two_qubit_gate_compilation.py | 2 +- 13 files changed, 56 insertions(+), 32 deletions(-) diff --git a/cirq-core/cirq/contrib/svg/svg.py b/cirq-core/cirq/contrib/svg/svg.py index be7a1d60c56..3a9be84bdeb 100644 --- a/cirq-core/cirq/contrib/svg/svg.py +++ b/cirq-core/cirq/contrib/svg/svg.py @@ -2,12 +2,14 @@ from typing import TYPE_CHECKING, List, Tuple, cast, Dict import matplotlib.textpath +import matplotlib.font_manager + if TYPE_CHECKING: import cirq QBLUE = '#1967d2' -FONT = "Arial" +FONT = matplotlib.font_manager.FontProperties(family="Arial") EMPTY_MOMENT_COLWIDTH = float(21) # assumed default column width diff --git a/cirq-core/cirq/devices/named_topologies.py b/cirq-core/cirq/devices/named_topologies.py index 5f32d8b1d5d..6aa46e19e94 100644 --- a/cirq-core/cirq/devices/named_topologies.py +++ b/cirq-core/cirq/devices/named_topologies.py @@ -74,7 +74,7 @@ def _node_and_coordinates( def draw_gridlike( - graph: nx.Graph, ax: plt.Axes = None, tilted: bool = True, **kwargs + graph: nx.Graph, ax: Optional[plt.Axes] = None, tilted: bool = True, **kwargs ) -> Dict[Any, Tuple[int, int]]: """Draw a grid-like graph using Matplotlib. diff --git a/cirq-core/cirq/experiments/qubit_characterizations.py b/cirq-core/cirq/experiments/qubit_characterizations.py index 114e2e28659..ed12b311e22 100644 --- a/cirq-core/cirq/experiments/qubit_characterizations.py +++ b/cirq-core/cirq/experiments/qubit_characterizations.py @@ -15,13 +15,13 @@ import dataclasses import itertools -from typing import Any, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING +from typing import Any, cast, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING import numpy as np from matplotlib import pyplot as plt # this is for older systems with matplotlib <3.2 otherwise 3d projections fail -from mpl_toolkits import mplot3d # pylint: disable=unused-import +from mpl_toolkits import mplot3d from cirq import circuits, ops, protocols if TYPE_CHECKING: @@ -89,8 +89,9 @@ def plot(self, ax: Optional[plt.Axes] = None, **plot_kwargs: Any) -> plt.Axes: """ show_plot = not ax if not ax: - fig, ax = plt.subplots(1, 1, figsize=(8, 8)) - ax.set_ylim([0, 1]) + fig, ax = plt.subplots(1, 1, figsize=(8, 8)) # pragma: no cover + ax = cast(plt.Axes, ax) # pragma: no cover + ax.set_ylim((0.0, 1.0)) # pragma: no cover ax.plot(self._num_cfds_seq, self._gnd_state_probs, 'ro-', **plot_kwargs) ax.set_xlabel(r"Number of Cliffords") ax.set_ylabel('Ground State Probability') @@ -541,7 +542,7 @@ def _find_inv_matrix(mat: np.ndarray, mat_sequence: np.ndarray) -> int: def _matrix_bar_plot( mat: np.ndarray, z_label: str, - ax: plt.Axes, + ax: mplot3d.axes3d.Axes3D, kets: Optional[Sequence[str]] = None, title: Optional[str] = None, ylim: Tuple[int, int] = (-1, 1), diff --git a/cirq-core/cirq/linalg/decompositions.py b/cirq-core/cirq/linalg/decompositions.py index 60dc0123640..43434ff4d1b 100644 --- a/cirq-core/cirq/linalg/decompositions.py +++ b/cirq-core/cirq/linalg/decompositions.py @@ -20,6 +20,7 @@ from typing import ( Any, Callable, + cast, Iterable, List, Optional, @@ -33,7 +34,7 @@ import matplotlib.pyplot as plt # this is for older systems with matplotlib <3.2 otherwise 3d projections fail -from mpl_toolkits import mplot3d # pylint: disable=unused-import +from mpl_toolkits import mplot3d import numpy as np from cirq import value, protocols @@ -554,7 +555,7 @@ def scatter_plot_normalized_kak_interaction_coefficients( interactions: Iterable[Union[np.ndarray, 'cirq.SupportsUnitary', 'KakDecomposition']], *, include_frame: bool = True, - ax: Optional[plt.Axes] = None, + ax: Optional[mplot3d.axes3d.Axes3D] = None, **kwargs, ): r"""Plots the interaction coefficients of many two-qubit operations. @@ -633,13 +634,13 @@ def scatter_plot_normalized_kak_interaction_coefficients( show_plot = not ax if not ax: fig = plt.figure() - ax = fig.add_subplot(1, 1, 1, projection='3d') + ax = cast(mplot3d.axes3d.Axes3D, fig.add_subplot(1, 1, 1, projection='3d')) def coord_transform( pts: Union[List[Tuple[int, int, int]], np.ndarray] - ) -> Tuple[Iterable[float], Iterable[float], Iterable[float]]: + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: if len(pts) == 0: - return [], [], [] + return np.array([]), np.array([]), np.array([]) xs, ys, zs = np.transpose(pts) return xs, zs, ys diff --git a/cirq-core/cirq/vis/heatmap.py b/cirq-core/cirq/vis/heatmap.py index e5598f59450..e672a2b8c27 100644 --- a/cirq-core/cirq/vis/heatmap.py +++ b/cirq-core/cirq/vis/heatmap.py @@ -15,6 +15,7 @@ from dataclasses import astuple, dataclass from typing import ( Any, + cast, Dict, List, Mapping, @@ -217,7 +218,7 @@ def _plot_colorbar( ) position = self._config['colorbar_position'] orien = 'vertical' if position in ('left', 'right') else 'horizontal' - colorbar = ax.figure.colorbar( + colorbar = cast(plt.Figure, ax.figure).colorbar( mappable, colorbar_ax, ax, orientation=orien, **self._config.get("colorbar_options", {}) ) colorbar_ax.tick_params(axis='y', direction='out') @@ -230,15 +231,15 @@ def _write_annotations( ax: plt.Axes, ) -> None: """Writes annotations to the center of cells. Internal.""" - for (center, annotation), facecolor in zip(centers_and_annot, collection.get_facecolors()): + for (center, annotation), facecolor in zip(centers_and_annot, collection.get_facecolor()): # Calculate the center of the cell, assuming that it is a square # centered at (x=col, y=row). if not annotation: continue x, y = center - face_luminance = vis_utils.relative_luminance(facecolor) + face_luminance = vis_utils.relative_luminance(facecolor) # type: ignore text_color = 'black' if face_luminance > 0.4 else 'white' - text_kwargs = dict(color=text_color, ha="center", va="center") + text_kwargs: Dict[str, Any] = dict(color=text_color, ha="center", va="center") text_kwargs.update(self._config.get('annotation_text_kwargs', {})) ax.text(x, y, annotation, **text_kwargs) @@ -295,6 +296,7 @@ def plot( show_plot = not ax if not ax: fig, ax = plt.subplots(figsize=(8, 8)) + ax = cast(plt.Axes, ax) original_config = copy.deepcopy(self._config) self.update_config(**kwargs) collection = self._plot_on_axis(ax) @@ -381,6 +383,7 @@ def plot( show_plot = not ax if not ax: fig, ax = plt.subplots(figsize=(8, 8)) + ax = cast(plt.Axes, ax) original_config = copy.deepcopy(self._config) self.update_config(**kwargs) qubits = set([q for qubits in self._value_map.keys() for q in qubits]) diff --git a/cirq-core/cirq/vis/heatmap_test.py b/cirq-core/cirq/vis/heatmap_test.py index 1ca493386f5..dceb00cff1c 100644 --- a/cirq-core/cirq/vis/heatmap_test.py +++ b/cirq-core/cirq/vis/heatmap_test.py @@ -34,6 +34,14 @@ def ax(): return figure.add_subplot(111) +def test_default_ax(): + row_col_list = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8)) + test_value_map = { + grid_qubit.GridQubit(row, col): np.random.random() for (row, col) in row_col_list + } + _, _ = heatmap.Heatmap(test_value_map).plot() + + @pytest.mark.parametrize('tuple_keys', [True, False]) def test_cells_positions(ax, tuple_keys): row_col_list = ((0, 5), (8, 1), (7, 0), (13, 5), (1, 6), (3, 2), (2, 8)) @@ -61,6 +69,8 @@ def test_two_qubit_heatmap(ax): title = "Two Qubit Interaction Heatmap" heatmap.TwoQubitInteractionHeatmap(value_map, title=title).plot(ax) assert ax.get_title() == title + # Test default axis + heatmap.TwoQubitInteractionHeatmap(value_map, title=title).plot() def test_invalid_args(): diff --git a/cirq-core/cirq/vis/histogram.py b/cirq-core/cirq/vis/histogram.py index f3b0a8047bc..88349097a97 100644 --- a/cirq-core/cirq/vis/histogram.py +++ b/cirq-core/cirq/vis/histogram.py @@ -100,9 +100,9 @@ def integrated_histogram( plot_options.update(kwargs) if cdf_on_x: - ax.step(bin_values, parameter_values, **plot_options) + ax.step(bin_values, parameter_values, **plot_options) # type: ignore else: - ax.step(parameter_values, bin_values, **plot_options) + ax.step(parameter_values, bin_values, **plot_options) # type: ignore set_semilog = ax.semilogy if cdf_on_x else ax.semilogx set_lim = ax.set_xlim if cdf_on_x else ax.set_ylim @@ -128,7 +128,7 @@ def integrated_histogram( if median_line: set_line( - np.median(float_data), + float(np.median(float_data)), linestyle='--', color=plot_options['color'], alpha=0.5, @@ -136,7 +136,7 @@ def integrated_histogram( ) if mean_line: set_line( - np.mean(float_data), + float(np.mean(float_data)), linestyle='-.', color=plot_options['color'], alpha=0.5, diff --git a/cirq-core/cirq/vis/state_histogram.py b/cirq-core/cirq/vis/state_histogram.py index 51ccfc5f073..3a3706cf04f 100644 --- a/cirq-core/cirq/vis/state_histogram.py +++ b/cirq-core/cirq/vis/state_histogram.py @@ -14,7 +14,7 @@ """Tool to visualize the results of a study.""" -from typing import Union, Optional, Sequence, SupportsFloat +from typing import cast, Optional, Sequence, SupportsFloat, Union import collections import numpy as np import matplotlib.pyplot as plt @@ -51,13 +51,13 @@ def get_state_histogram(result: 'result.Result') -> np.ndarray: def plot_state_histogram( data: Union['result.Result', collections.Counter, Sequence[SupportsFloat]], - ax: Optional['plt.Axis'] = None, + ax: Optional[plt.Axes] = None, *, tick_label: Optional[Sequence[str]] = None, xlabel: Optional[str] = 'qubit state', ylabel: Optional[str] = 'result count', title: Optional[str] = 'Result State Histogram', -) -> 'plt.Axis': +) -> plt.Axes: """Plot the state histogram from either a single result with repetitions or a histogram computed using `result.histogram()` or a flattened histogram of measurement results computed using `get_state_histogram`. @@ -87,6 +87,7 @@ def plot_state_histogram( show_fig = not ax if not ax: fig, ax = plt.subplots(1, 1) + ax = cast(plt.Axes, ax) if isinstance(data, result.Result): values = get_state_histogram(data) elif isinstance(data, collections.Counter): @@ -96,9 +97,12 @@ def plot_state_histogram( if tick_label is None: tick_label = [str(i) for i in range(len(values))] ax.bar(np.arange(len(values)), values, tick_label=tick_label) - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) - ax.set_title(title) + if xlabel: + ax.set_xlabel(xlabel) + if ylabel: + ax.set_ylabel(ylabel) + if title: + ax.set_title(title) if show_fig: fig.show() return ax diff --git a/cirq-core/cirq/vis/state_histogram_test.py b/cirq-core/cirq/vis/state_histogram_test.py index 220030d0e81..a922b12b1ff 100644 --- a/cirq-core/cirq/vis/state_histogram_test.py +++ b/cirq-core/cirq/vis/state_histogram_test.py @@ -78,6 +78,8 @@ def test_plot_state_histogram_result(): for r1, r2 in zip(ax1.get_children(), ax2.get_children()): if isinstance(r1, mpl.patches.Rectangle) and isinstance(r2, mpl.patches.Rectangle): assert str(r1) == str(r2) + # Test default axis + state_histogram.plot_state_histogram(expected_values) @pytest.mark.usefixtures('closefigures') diff --git a/cirq-google/cirq_google/engine/calibration.py b/cirq-google/cirq_google/engine/calibration.py index d28434da6c0..8e0ac4c1560 100644 --- a/cirq-google/cirq_google/engine/calibration.py +++ b/cirq-google/cirq_google/engine/calibration.py @@ -17,7 +17,7 @@ from collections import abc, defaultdict import datetime from itertools import cycle -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Sequence +from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union, Sequence import matplotlib as mpl import matplotlib.pyplot as plt @@ -277,6 +277,7 @@ def plot_histograms( show_plot = not ax if not ax: fig, ax = plt.subplots(1, 1) + ax = cast(plt.Axes, ax) if isinstance(keys, str): keys = [keys] @@ -322,7 +323,7 @@ def plot( show_plot = not fig if not fig: fig = plt.figure() - axs = fig.subplots(1, 2) + axs = cast(List[plt.Axes], fig.subplots(1, 2)) self.heatmap(key).plot(axs[0]) self.plot_histograms(key, axs[1]) if show_plot: diff --git a/docs/experiments/textbook_algorithms.ipynb b/docs/experiments/textbook_algorithms.ipynb index 182a91e5ff2..9bec52408b1 100644 --- a/docs/experiments/textbook_algorithms.ipynb +++ b/docs/experiments/textbook_algorithms.ipynb @@ -1010,7 +1010,7 @@ "outputs": [], "source": [ "\"\"\"Plot the results.\"\"\"\n", - "plt.style.use(\"seaborn-whitegrid\")\n", + "plt.style.use(\"seaborn-v0_8-whitegrid\")\n", "\n", "plt.plot(nvals, estimates, \"--o\", label=\"Phase estimation\")\n", "plt.axhline(theta, label=\"True value\", color=\"black\")\n", diff --git a/docs/start/intro.ipynb b/docs/start/intro.ipynb index 6929b08fce3..42599d0cfe2 100644 --- a/docs/start/intro.ipynb +++ b/docs/start/intro.ipynb @@ -1453,7 +1453,7 @@ " probs.append(prob[0])\n", "\n", "# Plot the probability of the ground state at each simulation step.\n", - "plt.style.use('seaborn-whitegrid')\n", + "plt.style.use('seaborn-v0_8-whitegrid')\n", "plt.plot(probs, 'o')\n", "plt.xlabel(\"Step\")\n", "plt.ylabel(\"Probability of ground state\");" @@ -1490,7 +1490,7 @@ "\n", "\n", "# Plot the probability of the ground state at each simulation step.\n", - "plt.style.use('seaborn-whitegrid')\n", + "plt.style.use('seaborn-v0_8-whitegrid')\n", "plt.plot(sampled_probs, 'o')\n", "plt.xlabel(\"Step\")\n", "plt.ylabel(\"Probability of ground state\");" diff --git a/examples/two_qubit_gate_compilation.py b/examples/two_qubit_gate_compilation.py index 2dd1a9e3260..9362ce9c12c 100644 --- a/examples/two_qubit_gate_compilation.py +++ b/examples/two_qubit_gate_compilation.py @@ -88,7 +88,7 @@ def main(samples: int = 1000, max_infidelity: float = 0.01): print(f'Maximum infidelity of "failed" compilation: {np.max(failed_infidelities_arr)}') plt.figure() - plt.hist(infidelities_arr, bins=25, range=[0, max_infidelity * 1.1]) + plt.hist(infidelities_arr, bins=25, range=(0.0, max_infidelity * 1.1)) # pragma: no cover ylim = plt.ylim() plt.plot([max_infidelity] * 2, ylim, '--', label='Maximum tabulation infidelity') plt.xlabel('Compiled gate infidelity vs target') From 13664940f44e1bee9be3bb5129046f1a951c8be9 Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Mon, 18 Sep 2023 15:40:27 -0700 Subject: [PATCH 10/19] Apply minor doc fixes (#6289) - quirk_url_to_circuit and quirk_json_to_circuit had weird HTML in them. - concat_ragged had a bunch of pre tags in the output. I believe these changes fix the markdown generation. --- cirq-core/cirq/circuits/circuit.py | 16 ++-- .../cirq/interop/quirk/url_to_circuit.py | 78 ++++++++++--------- .../cirq/protocols/apply_channel_protocol.py | 15 ++-- 3 files changed, 57 insertions(+), 52 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index a0c2fb0d94a..a23e87a2d4a 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -1465,14 +1465,14 @@ def concat_ragged( Beware that this method is *not* associative. For example: - >>> a, b = cirq.LineQubit.range(2) - >>> A = cirq.Circuit(cirq.H(a)) - >>> B = cirq.Circuit(cirq.H(b)) - >>> f = cirq.Circuit.concat_ragged - >>> f(f(A, B), A) == f(A, f(B, A)) - False - >>> len(f(f(f(A, B), A), B)) == len(f(f(A, f(B, A)), B)) - False + >>> a, b = cirq.LineQubit.range(2) + >>> A = cirq.Circuit(cirq.H(a)) + >>> B = cirq.Circuit(cirq.H(b)) + >>> f = cirq.Circuit.concat_ragged + >>> f(f(A, B), A) == f(A, f(B, A)) + False + >>> len(f(f(f(A, B), A), B)) == len(f(f(A, f(B, A)), B)) + False Args: *circuits: The circuits to concatenate. diff --git a/cirq-core/cirq/interop/quirk/url_to_circuit.py b/cirq-core/cirq/interop/quirk/url_to_circuit.py index 978ac574d18..8c27a44c538 100644 --- a/cirq-core/cirq/interop/quirk/url_to_circuit.py +++ b/cirq-core/cirq/interop/quirk/url_to_circuit.py @@ -77,39 +77,40 @@ def quirk_url_to_circuit( a billion laughs attack in the form of nested custom gates. Examples: - >>> print(cirq.quirk_url_to_circuit( - ... 'http://algassert.com/quirk#circuit={"cols":[["H"],["•","X"]]}' - ... )) - 0: ───H───@─── - │ - 1: ───────X─── - - >>> print(cirq.quirk_url_to_circuit( - ... 'http://algassert.com/quirk#circuit={"cols":[["H"],["•","X"]]}', - ... qubits=[cirq.NamedQubit('Alice'), cirq.NamedQubit('Bob')] - ... )) - Alice: ───H───@─── - │ - Bob: ─────────X─── - - >>> print(cirq.quirk_url_to_circuit( - ... 'http://algassert.com/quirk#circuit={"cols":[["iswap"]]}', - ... extra_cell_makers={'iswap': cirq.ISWAP})) - 0: ───iSwap─── - │ - 1: ───iSwap─── - - >>> print(cirq.quirk_url_to_circuit( - ... 'http://algassert.com/quirk#circuit={"cols":[["iswap"]]}', - ... extra_cell_makers=[ - ... cirq.interop.quirk.cells.CellMaker( - ... identifier='iswap', - ... size=2, - ... maker=lambda args: cirq.ISWAP(*args.qubits)) - ... ])) - 0: ───iSwap─── + + >>> print(cirq.quirk_url_to_circuit( + ... 'http://algassert.com/quirk#circuit={"cols":[["H"],["•","X"]]}' + ... )) + 0: ───H───@─── │ - 1: ───iSwap─── + 1: ───────X─── + + >>> print(cirq.quirk_url_to_circuit( + ... 'http://algassert.com/quirk#circuit={"cols":[["H"],["•","X"]]}', + ... qubits=[cirq.NamedQubit('Alice'), cirq.NamedQubit('Bob')] + ... )) + Alice: ───H───@─── + │ + Bob: ─────────X─── + + >>> print(cirq.quirk_url_to_circuit( + ... 'http://algassert.com/quirk#circuit={"cols":[["iswap"]]}', + ... extra_cell_makers={'iswap': cirq.ISWAP})) + 0: ───iSwap─── + │ + 1: ───iSwap─── + + >>> print(cirq.quirk_url_to_circuit( + ... 'http://algassert.com/quirk#circuit={"cols":[["iswap"]]}', + ... extra_cell_makers=[ + ... cirq.interop.quirk.cells.CellMaker( + ... identifier='iswap', + ... size=2, + ... maker=lambda args: cirq.ISWAP(*args.qubits)) + ... ])) + 0: ───iSwap─── + │ + 1: ───iSwap─── Returns: The parsed circuit. @@ -172,12 +173,13 @@ def quirk_json_to_circuit( a billion laughs attack in the form of nested custom gates. Examples: - >>> print(cirq.quirk_json_to_circuit( - ... {"cols":[["H"], ["•", "X"]]} - ... )) - 0: ───H───@─── - │ - 1: ───────X─── + + >>> print(cirq.quirk_json_to_circuit( + ... {"cols":[["H"], ["•", "X"]]} + ... )) + 0: ───H───@─── + │ + 1: ───────X─── Returns: The parsed circuit. diff --git a/cirq-core/cirq/protocols/apply_channel_protocol.py b/cirq-core/cirq/protocols/apply_channel_protocol.py index cd1a58b75d1..a6c8f283718 100644 --- a/cirq-core/cirq/protocols/apply_channel_protocol.py +++ b/cirq-core/cirq/protocols/apply_channel_protocol.py @@ -41,13 +41,16 @@ class ApplyChannelArgs: r"""Arguments for efficiently performing a channel. A channel performs the mapping - $$ - X \rightarrow \sum_i A_i X A_i^\dagger - $$ + + $$ + X \rightarrow \sum_i A_i X A_i^\dagger + $$ + for operators $A_i$ that satisfy the normalization condition - $$ - \sum_i A_i^\dagger A_i = I. - $$ + + $$ + \sum_i A_i^\dagger A_i = I. + $$ The receiving object is expected to mutate `target_tensor` so that it contains the density matrix after multiplication, and then return From d805d82375d221237d5dfe44d7c089c6911a0462 Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Tue, 19 Sep 2023 20:03:13 +0000 Subject: [PATCH 11/19] Make InternalGate hashable if all gate args are hashable (#6294) Review: @NoureldinYosri --- cirq-google/cirq_google/ops/internal_gate.py | 15 +++++++++-- .../cirq_google/ops/internal_gate_test.py | 26 ++++++++++++++++++- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/cirq-google/cirq_google/ops/internal_gate.py b/cirq-google/cirq_google/ops/internal_gate.py index 5822aa1fefa..f5e1f37d498 100644 --- a/cirq-google/cirq_google/ops/internal_gate.py +++ b/cirq-google/cirq_google/ops/internal_gate.py @@ -43,7 +43,7 @@ def __init__( self.gate_module = gate_module self.gate_name = gate_name self._num_qubits = num_qubits - self.gate_args = {arg: val for arg, val in kwargs.items()} + self.gate_args = kwargs def _num_qubits_(self) -> int: return self._num_qubits @@ -72,4 +72,15 @@ def _json_dict_(self) -> Dict[str, Any]: ) def _value_equality_values_(self): - return (self.gate_module, self.gate_name, self._num_qubits, self.gate_args) + hashable = True + for arg in self.gate_args.values(): + try: + hash(arg) + except TypeError: + hashable = False + return ( + self.gate_module, + self.gate_name, + self._num_qubits, + frozenset(self.gate_args.items()) if hashable else self.gate_args, + ) diff --git a/cirq-google/cirq_google/ops/internal_gate_test.py b/cirq-google/cirq_google/ops/internal_gate_test.py index 00fd480ccaa..b212d4f6151 100644 --- a/cirq-google/cirq_google/ops/internal_gate_test.py +++ b/cirq-google/cirq_google/ops/internal_gate_test.py @@ -14,6 +14,7 @@ import cirq import cirq_google +import pytest def test_internal_gate(): @@ -39,7 +40,30 @@ def test_internal_gate_with_no_args(): g = cirq_google.InternalGate(gate_name="GateWithNoArgs", gate_module='test', num_qubits=3) assert str(g) == 'test.GateWithNoArgs()' want_repr = ( - "cirq_google.InternalGate(gate_name='GateWithNoArgs', " "gate_module='test', num_qubits=3)" + "cirq_google.InternalGate(gate_name='GateWithNoArgs', gate_module='test', num_qubits=3)" ) assert repr(g) == want_repr assert cirq.qid_shape(g) == (2, 2, 2) + + +def test_internal_gate_with_hashable_args_is_hashable(): + hashable = cirq_google.InternalGate( + gate_name="GateWithHashableArgs", + gate_module='test', + num_qubits=3, + foo=1, + bar="2", + baz=(("a", 1),), + ) + _ = hash(hashable) + + unhashable = cirq_google.InternalGate( + gate_name="GateWithHashableArgs", + gate_module='test', + num_qubits=3, + foo=1, + bar="2", + baz={"a": 1}, + ) + with pytest.raises(TypeError, match="unhashable"): + _ = hash(unhashable) From 907ec3afc23e5b3e08d64aeb9cfb342ae11a01a0 Mon Sep 17 00:00:00 2001 From: eliottrosenberg <61400172+eliottrosenberg@users.noreply.github.com> Date: Tue, 19 Sep 2023 21:55:04 -0400 Subject: [PATCH 12/19] Try to make docstring render correctly (#6283) * Try to make docstring render correctly Docstring does not render correctly on the website: https://quantumai.google/reference/python/cirq/GeneralizedAmplitudeDampingChannel --- cirq-core/cirq/ops/common_channels.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/ops/common_channels.py b/cirq-core/cirq/ops/common_channels.py index 7867b7e6cfa..75a47f1b4dc 100644 --- a/cirq-core/cirq/ops/common_channels.py +++ b/cirq-core/cirq/ops/common_channels.py @@ -400,13 +400,10 @@ class GeneralizedAmplitudeDampingChannel(raw_types.Gate): This channel evolves a density matrix via $$ - \rho \rightarrow M_0 \rho M_0^\dagger - + M_1 \rho M_1^\dagger - + M_2 \rho M_2^\dagger - + M_3 \rho M_3^\dagger + \rho \rightarrow \sum_{i=0}^3 M_i \rho M_i^\dagger $$ - With + with $$ \begin{aligned} From 188bb94a09df38ad12b68ac220c648d068f271b3 Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Tue, 19 Sep 2023 21:44:45 -0700 Subject: [PATCH 13/19] Add registry sweep documentation to cirq_google (#6291) * Add registry sweep documentation to cirq_google - Add some simple documentation on how to use registry sweeps to cirq_google's device page (currently public but not linked to from table of contents) --- docs/google/devices.md | 69 +++++++++++++++++++++--------------------- 1 file changed, 35 insertions(+), 34 deletions(-) diff --git a/docs/google/devices.md b/docs/google/devices.md index 04db4b828f9..572eeffd86f 100644 --- a/docs/google/devices.md +++ b/docs/google/devices.md @@ -27,7 +27,7 @@ of total circuit run-time. Circuits that exceed this limit will return a ### Moment structure -The hardware will attempt to run your circuit as it exists in cirq to the +The hardware will attempt to run your circuit as it exists in Cirq to the extent possible. The device will respect the moment structure of your circuit and will execute successive moments in a serial fashion. @@ -65,11 +65,40 @@ cirq.Circuit( The duration of a moment is the time of its longest gate. For example, if a moment has gates of duration 12ns, 25ns, and 32ns, the entire moment -will take 32ns. Qubits executing the shorter gtes will idle during the rest +will take 32ns. Qubits executing the shorter gates will idle during the rest of the time. To minimize the duration of the circuit, it is best to align gates of the same duration together when possible. See the [best practices](./best_practices.ipynb) for more details. +## Device Parameter Sweeps + +Certain device parameters can be changed for the duration of +a circuit in order to support hardware parameter sweeps. For instance, +frequencies, amplitudes, and various other parameters can be modified +in order to find optimal values or explore the parameter space. + +These parameter names are generally not public, so you will need to +work with a Google sponsor or resident in order to access the proper +key names. These parameters are specified as lists of strings representing +a path from the device config's folder (or the "sample folder"). + +These keys can be swept like any other symbol using the +`cirq_google.study.DeviceParameter` variable. For instance, the +following code will sweep qubit (4,8)'s pi amplitude from 0.0 to 1.0 +in 0.02 increments. + + +``` +descriptor = cirq_google.study.DeviceParameter( ["q4_8", "piAmp"]) +sweep = cirq.Linspace("q4_8.piAmp", 0, 1, 51, metadata=descriptor) +``` + +Any `DeviceParameter` keys that are set to a single value using a `cirq.Points` +object will change that value for all circuits run. + +If units are required, they should be specified as a string (such as 'MHz') +using the `units` argument of the `DeviceParameter`. + ## Gates supported The following lists the gates supported by Google devices. @@ -188,7 +217,7 @@ $$ This gate has a duration of 32ns and can be used in `cirq_google.SQRT_ISWAP_GATESET` or in the `cirq_google.FSIM_GATESET`. -This gate is implemented by using an entangling gate surrounding by +This gate is implemented by using an entangling gate surrounded by Z gates. The preceding Z gates are physical Z gates and will absorb any phases that have accumulated through the use of Virtual Z gates. Following the entangler are virtual Z gates to match phases back. All @@ -238,7 +267,8 @@ expressions, but only a subset of Sympy expression types are supported: `sympy.Symbol`, `sympy.Add`, `sympy.Mul`, and `sympy.Pow`. ## Specific Device Layouts -The following devices are provided as part of cirq and can help you get your + +The following devices are provided as part of Cirq and can help you get your circuit ready for running on hardware by verifying that you are using appropriate qubits. @@ -272,7 +302,7 @@ It can be accessed using `cirq.GridQubit(row, col)` using grid coordinates speci 9 ----I----- ``` -It can be accessing by using `cirq_google.Sycamore`. This device has two possible +It can be accessed by using `cirq_google.Sycamore`. This device has two possible two-qubits gates that can be used. * Square root of ISWAP. The gate `cirq.ISWAP ** 0.5` or `cirq.ISWAP ** -0.5` can be @@ -304,32 +334,3 @@ with and presents less hardware-related complications than using the full Sycamo This grid can be accessed using `cirq_google.Sycamore23` and uses the same gate sets and compilation as the Sycamore device. - - -### Bristlecone - -The Bristlecone processor is a 72 qubit device -[announced by Google in 2018](https://ai.googleblog.com/2018/03/a-preview-of-bristlecone-googles-new.html). - -The device is arrayed on a grid in a diamond pattern like this. - -``` - 11 - 012345678901 -0 -----AB----- -1 ----ABCD---- -2 ---ABCDEF--- -3 --ABCDEFGH-- -4 -ABCDEFGHIJ- -5 ABCDEFGHIJKL -6 -CDEFGHIJKL- -7 --EFGHIJKL-- -8 ---GHIJKL--- -9 ----IJKL---- -10-----KL----- -``` - -It can be accessing by using `cirq_google.Bristlecone`. Circuits can be compiled to it by using -`cirq_google.optimized_for_xmon` or by using `cirq_google.optimized_for_sycamore` with -optimizer_type `xmon`. - From 8e4e7d147d6d4ea5dd8111ba21a715aa3acd955c Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Wed, 20 Sep 2023 13:23:44 -0700 Subject: [PATCH 14/19] Add bitsize field to Cirq-FT Registers (#6286) --- cirq-ft/cirq_ft/algos/swap_network.py | 4 +- .../cirq_ft/infra/gate_with_registers.ipynb | 6 +-- cirq-ft/cirq_ft/infra/gate_with_registers.py | 44 ++++++++++++------- .../cirq_ft/infra/gate_with_registers_test.py | 15 ++++--- 4 files changed, 43 insertions(+), 26 deletions(-) diff --git a/cirq-ft/cirq_ft/algos/swap_network.py b/cirq-ft/cirq_ft/algos/swap_network.py index 279ab33be38..1dd5ca88879 100644 --- a/cirq-ft/cirq_ft/algos/swap_network.py +++ b/cirq-ft/cirq_ft/algos/swap_network.py @@ -152,7 +152,9 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: @cached_property def target_registers(self) -> Tuple[infra.Register, ...]: - return (infra.Register('target', (self.n_target_registers, self.target_bitsize)),) + return ( + infra.Register('target', bitsize=self.target_bitsize, shape=self.n_target_registers), + ) @cached_property def registers(self) -> infra.Registers: diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb b/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb index 6afb6d49d4f..ef72a1e1479 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb +++ b/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb @@ -39,7 +39,7 @@ "source": [ "## `Registers`\n", "\n", - "`Register` objects have a name and a shape. `Registers` is an ordered collection of `Register` with some helpful methods." + "`Register` objects have a name, a bitsize and a shape. `Registers` is an ordered collection of `Register` with some helpful methods." ] }, { @@ -51,8 +51,8 @@ "source": [ "from cirq_ft import Register, Registers, infra\n", "\n", - "control_reg = Register(name='control', shape=(2,))\n", - "target_reg = Register(name='target', shape=(3,))\n", + "control_reg = Register(name='control', bitsize=2)\n", + "target_reg = Register(name='target', bitsize=3)\n", "control_reg, target_reg" ] }, diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers.py b/cirq-ft/cirq_ft/infra/gate_with_registers.py index 7139b66a65a..b4567591c67 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers.py +++ b/cirq-ft/cirq_ft/infra/gate_with_registers.py @@ -32,8 +32,9 @@ class Register: """ name: str + bitsize: int shape: Tuple[int, ...] = attr.field( - converter=lambda v: (v,) if isinstance(v, int) else tuple(v) + converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=() ) def all_idxs(self) -> Iterable[Tuple[int, ...]]: @@ -45,15 +46,14 @@ def total_bits(self) -> int: This is the product of each of the dimensions in `shape`. """ - return int(np.product(self.shape)) + return self.bitsize * int(np.product(self.shape)) def __repr__(self): - return f'cirq_ft.Register(name="{self.name}", shape={self.shape})' + return f'cirq_ft.Register(name="{self.name}", bitsize={self.bitsize}, shape={self.shape})' def total_bits(registers: Iterable[Register]) -> int: """Sum of `reg.total_bits()` for each register `reg` in input `registers`.""" - return sum(reg.total_bits() for reg in registers) @@ -65,7 +65,9 @@ def split_qubits( qubit_regs = {} base = 0 for reg in registers: - qubit_regs[reg.name] = np.array(qubits[base : base + reg.total_bits()]).reshape(reg.shape) + qubit_regs[reg.name] = np.array(qubits[base : base + reg.total_bits()]).reshape( + reg.shape + (reg.bitsize,) + ) base += reg.total_bits() return qubit_regs @@ -82,9 +84,10 @@ def merge_qubits( raise ValueError(f"All qubit registers must be present. {reg.name} not in qubit_regs") qubits = qubit_regs[reg.name] qubits = np.array([qubits] if isinstance(qubits, cirq.Qid) else qubits) - if qubits.shape != reg.shape: + full_shape = reg.shape + (reg.bitsize,) + if qubits.shape != full_shape: raise ValueError( - f'{reg.name} register must of shape {reg.shape} but is of shape {qubits.shape}' + f'{reg.name} register must of shape {full_shape} but is of shape {qubits.shape}' ) ret += qubits.flatten().tolist() return ret @@ -94,13 +97,16 @@ def get_named_qubits(registers: Iterable[Register]) -> Dict[str, NDArray[cirq.Qi """Returns a dictionary of appropriately shaped named qubit registers for input `registers`.""" def _qubit_array(reg: Register): - qubits = np.empty(reg.shape, dtype=object) + qubits = np.empty(reg.shape + (reg.bitsize,), dtype=object) for ii in reg.all_idxs(): - qubits[ii] = cirq.NamedQubit(f'{reg.name}[{", ".join(str(i) for i in ii)}]') + for j in range(reg.bitsize): + prefix = "" if not ii else f'[{", ".join(str(i) for i in ii)}]' + suffix = "" if reg.bitsize == 1 else f"[{j}]" + qubits[ii + (j,)] = cirq.NamedQubit(reg.name + prefix + suffix) return qubits def _qubits_for_reg(reg: Register): - if len(reg.shape) > 1: + if len(reg.shape) > 0: return _qubit_array(reg) return np.array( @@ -130,8 +136,8 @@ def __repr__(self): return f'cirq_ft.Registers({self._registers})' @classmethod - def build(cls, **registers: Union[int, Tuple[int, ...]]) -> 'Registers': - return cls(Register(name=k, shape=v) for k, v in registers.items()) + def build(cls, **registers: int) -> 'Registers': + return cls(Register(name=k, bitsize=v) for k, v in registers.items()) @overload def __getitem__(self, key: int) -> Register: @@ -216,23 +222,29 @@ class SelectionRegister(Register): >>> assert len(flat_indices) == N * M * L """ + name: str + bitsize: int iteration_length: int = attr.field() + shape: Tuple[int, ...] = attr.field( + converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=() + ) @iteration_length.default def _default_iteration_length(self): - return 2 ** self.shape[0] + return 2**self.bitsize @iteration_length.validator def validate_iteration_length(self, attribute, value): - if len(self.shape) != 1: + if len(self.shape) != 0: raise ValueError(f'Selection register {self.name} should be flat. Found {self.shape=}') - if not (0 <= value <= 2 ** self.shape[0]): - raise ValueError(f'iteration length must be in range [0, 2^{self.shape[0]}]') + if not (0 <= value <= 2**self.bitsize): + raise ValueError(f'iteration length must be in range [0, 2^{self.bitsize}]') def __repr__(self) -> str: return ( f'cirq_ft.SelectionRegister(' f'name="{self.name}", ' + f'bitsize={self.bitsize}, ' f'shape={self.shape}, ' f'iteration_length={self.iteration_length})' ) diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers_test.py b/cirq-ft/cirq_ft/infra/gate_with_registers_test.py index 7560cb7a357..57af2354e48 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers_test.py +++ b/cirq-ft/cirq_ft/infra/gate_with_registers_test.py @@ -21,8 +21,9 @@ def test_register(): - r = cirq_ft.Register("my_reg", 5) - assert r.shape == (5,) + r = cirq_ft.Register("my_reg", 5, (1, 2)) + assert r.bitsize == 5 + assert r.shape == (1, 2) def test_registers(): @@ -103,12 +104,12 @@ def test_selection_registers_consistent(): _ = cirq_ft.SelectionRegister('a', 3, 10) with pytest.raises(ValueError, match="should be flat"): - _ = cirq_ft.SelectionRegister('a', (3, 5), 5) + _ = cirq_ft.SelectionRegister('a', bitsize=1, shape=(3, 5), iteration_length=5) selection_reg = cirq_ft.Registers( [ - cirq_ft.SelectionRegister('n', shape=3, iteration_length=5), - cirq_ft.SelectionRegister('m', shape=4, iteration_length=12), + cirq_ft.SelectionRegister('n', bitsize=3, iteration_length=5), + cirq_ft.SelectionRegister('m', bitsize=4, iteration_length=12), ] ) assert selection_reg[0] == cirq_ft.SelectionRegister('n', 3, 5) @@ -122,7 +123,9 @@ def test_registers_getitem_raises(): with pytest.raises(IndexError, match="must be of the type"): _ = g[2.5] - selection_reg = cirq_ft.Registers([cirq_ft.SelectionRegister('n', shape=3, iteration_length=5)]) + selection_reg = cirq_ft.Registers( + [cirq_ft.SelectionRegister('n', bitsize=3, iteration_length=5)] + ) with pytest.raises(IndexError, match='must be of the type'): _ = selection_reg[2.5] From acbc6247df69ca0d894178dc41f4459e1185d722 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Mon, 25 Sep 2023 11:36:15 -0700 Subject: [PATCH 15/19] Do not allow creating registers with bitsize 0 (#6298) * Do not allow creating registers with bitsize 0 * Fix mypy errors --- cirq-ft/cirq_ft/algos/and_gate.py | 6 +++++- cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py | 2 +- .../cirq_ft/algos/prepare_uniform_superposition.py | 2 +- cirq-ft/cirq_ft/algos/qrom.py | 6 ++++-- cirq-ft/cirq_ft/algos/selected_majorana_fermion.py | 2 +- cirq-ft/cirq_ft/algos/state_preparation.py | 2 +- cirq-ft/cirq_ft/infra/gate_with_registers.py | 11 ++++++++--- cirq-ft/cirq_ft/infra/gate_with_registers_test.py | 5 ++++- 8 files changed, 25 insertions(+), 11 deletions(-) diff --git a/cirq-ft/cirq_ft/algos/and_gate.py b/cirq-ft/cirq_ft/algos/and_gate.py index f308926d632..b34f632f5ff 100644 --- a/cirq-ft/cirq_ft/algos/and_gate.py +++ b/cirq-ft/cirq_ft/algos/and_gate.py @@ -141,7 +141,11 @@ def _decompose_via_tree( def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: - control, ancilla, target = quregs['control'], quregs['ancilla'], quregs['target'] + control, ancilla, target = ( + quregs['control'], + quregs.get('ancilla', np.array([])), + quregs['target'], + ) if len(self.cv) == 2: yield self._decompose_single_and( self.cv[0], self.cv[1], control[0], control[1], *target diff --git a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py index e3bb08be143..25f80dfa00d 100644 --- a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py +++ b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py @@ -77,7 +77,7 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: @cached_property def target_registers(self) -> Tuple[infra.Register, ...]: - total_iteration_size = np.product( + total_iteration_size = np.prod( tuple(reg.iteration_length for reg in self.selection_registers) ) return (infra.Register('target', int(total_iteration_size)),) diff --git a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py index 374415e90bc..6497e3d65c5 100644 --- a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py +++ b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py @@ -69,7 +69,7 @@ def decompose_from_registers( context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid], # type:ignore[type-var] ) -> cirq.OP_TREE: - controls, target = quregs['controls'], quregs['target'] + controls, target = quregs.get('controls', ()), quregs['target'] # Find K and L as per https://arxiv.org/abs/1805.03662 Fig 12. n, k = self.n, 0 while n > 1 and n % 2 == 0: diff --git a/cirq-ft/cirq_ft/algos/qrom.py b/cirq-ft/cirq_ft/algos/qrom.py index 8d09d82ed9b..9feb90ad125 100644 --- a/cirq-ft/cirq_ft/algos/qrom.py +++ b/cirq-ft/cirq_ft/algos/qrom.py @@ -111,7 +111,9 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: @cached_property def target_registers(self) -> Tuple[infra.Register, ...]: - return tuple(infra.Register(f'target{i}', l) for i, l in enumerate(self.target_bitsizes)) + return tuple( + infra.Register(f'target{i}', l) for i, l in enumerate(self.target_bitsizes) if l + ) def __repr__(self) -> str: data_repr = f"({','.join(cirq._compat.proper_repr(d) for d in self.data)})" @@ -129,7 +131,7 @@ def _load_nth_data( **target_regs: NDArray[cirq.Qid], # type: ignore[type-var] ) -> cirq.OP_TREE: for i, d in enumerate(self.data): - target = target_regs[f'target{i}'] + target = target_regs.get(f'target{i}', ()) for q, bit in zip(target, f'{int(d[selection_idx]):0{len(target)}b}'): if int(bit): yield gate(q) diff --git a/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py b/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py index 877c81f39a3..a97eb752adb 100644 --- a/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py +++ b/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py @@ -77,7 +77,7 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: @cached_property def target_registers(self) -> Tuple[infra.Register, ...]: - total_iteration_size = np.product( + total_iteration_size = np.prod( tuple(reg.iteration_length for reg in self.selection_registers) ) return (infra.Register('target', int(total_iteration_size)),) diff --git a/cirq-ft/cirq_ft/algos/state_preparation.py b/cirq-ft/cirq_ft/algos/state_preparation.py index bec54f50a6b..aa660b5ebf8 100644 --- a/cirq-ft/cirq_ft/algos/state_preparation.py +++ b/cirq-ft/cirq_ft/algos/state_preparation.py @@ -167,7 +167,7 @@ def decompose_from_registers( **quregs: NDArray[cirq.Qid], # type:ignore[type-var] ) -> cirq.OP_TREE: selection, less_than_equal = quregs['selection'], quregs['less_than_equal'] - sigma_mu, alt, keep = quregs['sigma_mu'], quregs['alt'], quregs['keep'] + sigma_mu, alt, keep = quregs.get('sigma_mu', ()), quregs['alt'], quregs.get('keep', ()) N = self.selection_registers[0].iteration_length yield prepare_uniform_superposition.PrepareUniformSuperposition(N).on(*selection) yield cirq.H.on_each(*sigma_mu) diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers.py b/cirq-ft/cirq_ft/infra/gate_with_registers.py index b4567591c67..624397ab479 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers.py +++ b/cirq-ft/cirq_ft/infra/gate_with_registers.py @@ -32,11 +32,16 @@ class Register: """ name: str - bitsize: int + bitsize: int = attr.field() shape: Tuple[int, ...] = attr.field( converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=() ) + @bitsize.validator + def bitsize_validator(self, attribute, value): + if value <= 0: + raise ValueError(f"Bitsize for {self=} must be a positive integer. Found {value}.") + def all_idxs(self) -> Iterable[Tuple[int, ...]]: """Iterate over all possible indices of a multidimensional register.""" yield from itertools.product(*[range(sh) for sh in self.shape]) @@ -46,7 +51,7 @@ def total_bits(self) -> int: This is the product of each of the dimensions in `shape`. """ - return self.bitsize * int(np.product(self.shape)) + return self.bitsize * int(np.prod(self.shape)) def __repr__(self): return f'cirq_ft.Register(name="{self.name}", bitsize={self.bitsize}, shape={self.shape})' @@ -137,7 +142,7 @@ def __repr__(self): @classmethod def build(cls, **registers: int) -> 'Registers': - return cls(Register(name=k, bitsize=v) for k, v in registers.items()) + return cls(Register(name=k, bitsize=v) for k, v in registers.items() if v > 0) @overload def __getitem__(self, key: int) -> Register: diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers_test.py b/cirq-ft/cirq_ft/infra/gate_with_registers_test.py index 57af2354e48..77e60aacbe8 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers_test.py +++ b/cirq-ft/cirq_ft/infra/gate_with_registers_test.py @@ -25,6 +25,9 @@ def test_register(): assert r.bitsize == 5 assert r.shape == (1, 2) + with pytest.raises(ValueError, match="must be a positive integer"): + _ = cirq_ft.Register("zero bitsize register", bitsize=0) + def test_registers(): r1 = cirq_ft.Register("r1", 5) @@ -96,7 +99,7 @@ def test_selection_registers_indexing(n, N, m, M): assert np.ravel_multi_index((x, y), (N, M)) == x * M + y assert np.unravel_index(x * M + y, (N, M)) == (x, y) - assert np.product(tuple(reg.iteration_length for reg in regs)) == N * M + assert np.prod(tuple(reg.iteration_length for reg in regs)) == N * M def test_selection_registers_consistent(): From 61d967112ba23cc839b0e922bd42878024a3e738 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Mon, 25 Sep 2023 21:01:17 +0000 Subject: [PATCH 16/19] Fix stream_manager_test type warnings (#6299) Co-authored-by: Tanuj Khattar --- cirq-google/cirq_google/engine/stream_manager_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-google/cirq_google/engine/stream_manager_test.py b/cirq-google/cirq_google/engine/stream_manager_test.py index 42e6defbcc8..635f94eee15 100644 --- a/cirq-google/cirq_google/engine/stream_manager_test.py +++ b/cirq-google/cirq_google/engine/stream_manager_test.py @@ -70,12 +70,12 @@ class FakeQuantumRunStream: _REQUEST_STOPPED = 'REQUEST_STOPPED' - def __init__(self): + def __init__(self) -> None: self.all_stream_requests: List[quantum.QuantumRunStreamRequest] = [] self.all_cancel_requests: List[quantum.CancelQuantumJobRequest] = [] self._executor = AsyncioExecutor.instance() self._request_buffer = duet.AsyncCollector[quantum.QuantumRunStreamRequest]() - self._request_iterator_stopped = duet.AwaitableFuture() + self._request_iterator_stopped: duet.AwaitableFuture[None] = duet.AwaitableFuture() # asyncio.Queue needs to be initialized inside the asyncio thread because all callers need # to use the same event loop. self._responses_and_exceptions_future: duet.AwaitableFuture[ From c69632346948a802199685ead3cac50a5708905e Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Mon, 25 Sep 2023 15:35:52 -0700 Subject: [PATCH 17/19] Rename `cirq_ft.Registers` to `cirq_ft.Signature` to match data types in Qualtran (#6300) * Rename cirq_ft.Registers to cirq_ft.Signature to match data types in Qualtran * Fix coverage tests --------- Co-authored-by: Fionn Malone --- cirq-ft/cirq_ft/__init__.py | 2 +- cirq-ft/cirq_ft/algos/and_gate.ipynb | 2 +- cirq-ft/cirq_ft/algos/and_gate.py | 4 +- cirq-ft/cirq_ft/algos/and_gate_test.py | 10 +- .../algos/apply_gate_to_lth_target.ipynb | 6 +- .../cirq_ft/algos/apply_gate_to_lth_target.py | 4 +- .../algos/apply_gate_to_lth_target_test.py | 12 +-- cirq-ft/cirq_ft/algos/arithmetic_gates.py | 16 +-- cirq-ft/cirq_ft/algos/generic_select_test.py | 2 +- cirq-ft/cirq_ft/algos/hubbard_model.ipynb | 2 +- cirq-ft/cirq_ft/algos/hubbard_model.py | 32 +++--- cirq-ft/cirq_ft/algos/hubbard_model_test.py | 2 +- .../mean_estimation/complex_phase_oracle.py | 4 +- .../mean_estimation_operator.py | 10 +- .../mean_estimation_operator_test.py | 4 +- .../algos/multi_control_multi_target_pauli.py | 8 +- .../phase_estimation_of_quantum_walk.ipynb | 4 +- .../algos/prepare_uniform_superposition.py | 8 +- .../prepare_uniform_superposition_test.py | 2 +- .../algos/programmable_rotation_gate_array.py | 4 +- .../programmable_rotation_gate_array_test.py | 9 +- cirq-ft/cirq_ft/algos/qrom.ipynb | 6 +- cirq-ft/cirq_ft/algos/qrom.py | 8 +- cirq-ft/cirq_ft/algos/qrom_test.py | 6 +- .../algos/qubitization_walk_operator.py | 10 +- .../algos/qubitization_walk_operator_test.py | 4 +- .../cirq_ft/algos/reflection_using_prepare.py | 4 +- .../algos/reflection_using_prepare_test.py | 8 +- cirq-ft/cirq_ft/algos/select_and_prepare.py | 8 +- cirq-ft/cirq_ft/algos/select_swap_qrom.py | 10 +- .../cirq_ft/algos/select_swap_qrom_test.py | 4 +- .../algos/selected_majorana_fermion.py | 4 +- .../algos/selected_majorana_fermion_test.py | 10 +- cirq-ft/cirq_ft/algos/state_preparation.ipynb | 6 +- cirq-ft/cirq_ft/algos/state_preparation.py | 10 +- cirq-ft/cirq_ft/algos/swap_network.ipynb | 2 +- cirq-ft/cirq_ft/algos/swap_network.py | 10 +- cirq-ft/cirq_ft/algos/swap_network_test.py | 2 +- cirq-ft/cirq_ft/algos/unary_iteration.ipynb | 16 +-- cirq-ft/cirq_ft/algos/unary_iteration_gate.py | 16 +-- .../algos/unary_iteration_gate_test.py | 4 +- cirq-ft/cirq_ft/infra/__init__.py | 3 +- .../cirq_ft/infra/gate_with_registers.ipynb | 18 ++-- cirq-ft/cirq_ft/infra/gate_with_registers.py | 99 ++++++++++++------- .../cirq_ft/infra/gate_with_registers_test.py | 47 ++++----- cirq-ft/cirq_ft/infra/jupyter_tools.py | 8 +- cirq-ft/cirq_ft/infra/t_complexity.ipynb | 8 +- .../infra/t_complexity_protocol_test.py | 8 +- cirq-ft/cirq_ft/infra/testing.py | 8 +- cirq-ft/cirq_ft/infra/testing_test.py | 2 +- 50 files changed, 269 insertions(+), 227 deletions(-) diff --git a/cirq-ft/cirq_ft/__init__.py b/cirq-ft/cirq_ft/__init__.py index 47bf47cf660..48a7e334ef2 100644 --- a/cirq-ft/cirq_ft/__init__.py +++ b/cirq-ft/cirq_ft/__init__.py @@ -47,7 +47,7 @@ GateWithRegisters, GreedyQubitManager, Register, - Registers, + Signature, SelectionRegister, TComplexity, map_clean_and_borrowable_qubits, diff --git a/cirq-ft/cirq_ft/algos/and_gate.ipynb b/cirq-ft/cirq_ft/algos/and_gate.ipynb index 498081f0cb4..1c47eec56ab 100644 --- a/cirq-ft/cirq_ft/algos/and_gate.ipynb +++ b/cirq-ft/cirq_ft/algos/and_gate.ipynb @@ -66,7 +66,7 @@ "from cirq_ft import And, infra\n", "\n", "gate = And()\n", - "r = gate.registers\n", + "r = gate.signature\n", "quregs = infra.get_named_qubits(r)\n", "operation = gate.on_registers(**quregs)\n", "circuit = cirq.Circuit(operation)\n", diff --git a/cirq-ft/cirq_ft/algos/and_gate.py b/cirq-ft/cirq_ft/algos/and_gate.py index b34f632f5ff..1d54d1a66b4 100644 --- a/cirq-ft/cirq_ft/algos/and_gate.py +++ b/cirq-ft/cirq_ft/algos/and_gate.py @@ -60,8 +60,8 @@ def _validate_cv(self, attribute, value): raise ValueError(f"And gate needs at-least 2 control values, supplied {value} instead.") @cached_property - def registers(self) -> infra.Registers: - return infra.Registers.build(control=len(self.cv), ancilla=len(self.cv) - 2, target=1) + def signature(self) -> infra.Signature: + return infra.Signature.build(control=len(self.cv), ancilla=len(self.cv) - 2, target=1) def __pow__(self, power: int) -> "And": if power == 1: diff --git a/cirq-ft/cirq_ft/algos/and_gate_test.py b/cirq-ft/cirq_ft/algos/and_gate_test.py index 70de51a205b..f79962cfdfa 100644 --- a/cirq-ft/cirq_ft/algos/and_gate_test.py +++ b/cirq-ft/cirq_ft/algos/and_gate_test.py @@ -45,17 +45,17 @@ def random_cv(n: int) -> List[int]: @pytest.mark.parametrize("cv", [[1] * 3, random_cv(5), random_cv(6), random_cv(7)]) def test_multi_controlled_and_gate(cv: List[int]): gate = cirq_ft.And(cv) - r = gate.registers - assert r['ancilla'].total_bits() == r['control'].total_bits() - 2 + r = gate.signature + assert r.get_left('ancilla').total_bits() == r.get_left('control').total_bits() - 2 quregs = infra.get_named_qubits(r) and_op = gate.on_registers(**quregs) circuit = cirq.Circuit(and_op) input_controls = [cv] + [random_cv(len(cv)) for _ in range(10)] - qubit_order = infra.merge_qubits(gate.registers, **quregs) + qubit_order = infra.merge_qubits(gate.signature, **quregs) for input_control in input_controls: - initial_state = input_control + [0] * (r['ancilla'].total_bits() + 1) + initial_state = input_control + [0] * (r.get_left('ancilla').total_bits() + 1) result = cirq.Simulator(dtype=np.complex128).simulate( circuit, initial_state=initial_state, qubit_order=qubit_order ) @@ -78,7 +78,7 @@ def test_multi_controlled_and_gate(cv: List[int]): def test_and_gate_diagram(): gate = cirq_ft.And((1, 0, 1, 0, 1, 0)) - qubit_regs = infra.get_named_qubits(gate.registers) + qubit_regs = infra.get_named_qubits(gate.signature) op = gate.on_registers(**qubit_regs) # Qubit order should be alternating (control, ancilla) pairs. c_and_a = sum(zip(qubit_regs["control"][1:], qubit_regs["ancilla"]), ()) + ( diff --git a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb index 3e34704c619..236832f691c 100644 --- a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb +++ b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.ipynb @@ -69,9 +69,9 @@ "`selection`-th qubit of `target` all controlled by the `control` register.\n", "\n", "#### Parameters\n", - " - `selection_regs`: Indexing `select` registers of type Tuple[`SelectionRegister`, ...]. It also contains information about the iteration length of each selection register.\n", + " - `selection_regs`: Indexing `select` signature of type Tuple[`SelectionRegister`, ...]. It also contains information about the iteration length of each selection register.\n", " - `nth_gate`: A function mapping the composite selection index to a single-qubit gate.\n", - " - `control_regs`: Control registers for constructing a controlled version of the gate.\n" + " - `control_regs`: Control signature for constructing a controlled version of the gate.\n" ] }, { @@ -91,7 +91,7 @@ "apply_z_to_odd = cirq_ft.ApplyGateToLthQubit(\n", " cirq_ft.SelectionRegister('selection', 3, 4),\n", " nth_gate=_z_to_odd,\n", - " control_regs=cirq_ft.Registers.build(control=2),\n", + " control_regs=cirq_ft.Signature.build(control=2),\n", ")\n", "\n", "g = cq_testing.GateHelper(\n", diff --git a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py index 25f80dfa00d..5c888b2f5a3 100644 --- a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py +++ b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py @@ -37,10 +37,10 @@ class ApplyGateToLthQubit(unary_iteration_gate.UnaryIterationGate): `selection`-th qubit of `target` all controlled by the `control` register. Args: - selection_regs: Indexing `select` registers of type Tuple[`SelectionRegisters`, ...]. + selection_regs: Indexing `select` signature of type Tuple[`SelectionRegisters`, ...]. It also contains information about the iteration length of each selection register. nth_gate: A function mapping the composite selection index to a single-qubit gate. - control_regs: Control registers for constructing a controlled version of the gate. + control_regs: Control signature for constructing a controlled version of the gate. References: [Encoding Electronic Spectra in Quantum Circuits with Linear T Complexity] diff --git a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py index 2c2e29e7c0c..074418048a5 100644 --- a/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py +++ b/cirq-ft/cirq_ft/algos/apply_gate_to_lth_target_test.py @@ -54,10 +54,10 @@ def test_apply_gate_to_lth_qubit_diagram(): gate = cirq_ft.ApplyGateToLthQubit( cirq_ft.SelectionRegister('selection', 3, 5), lambda n: cirq.Z if n & 1 else cirq.I, - control_regs=cirq_ft.Registers.build(control=2), + control_regs=cirq_ft.Signature.build(control=2), ) - circuit = cirq.Circuit(gate.on_registers(**infra.get_named_qubits(gate.registers))) - qubits = list(q for v in infra.get_named_qubits(gate.registers).values() for q in v) + circuit = cirq.Circuit(gate.on_registers(**infra.get_named_qubits(gate.signature))) + qubits = list(q for v in infra.get_named_qubits(gate.signature).values() for q in v) cirq.testing.assert_has_diagram( circuit, """ @@ -89,11 +89,11 @@ def test_apply_gate_to_lth_qubit_make_on(): gate = cirq_ft.ApplyGateToLthQubit( cirq_ft.SelectionRegister('selection', 3, 5), lambda n: cirq.Z if n & 1 else cirq.I, - control_regs=cirq_ft.Registers.build(control=2), + control_regs=cirq_ft.Signature.build(control=2), ) - op = gate.on_registers(**infra.get_named_qubits(gate.registers)) + op = gate.on_registers(**infra.get_named_qubits(gate.signature)) op2 = cirq_ft.ApplyGateToLthQubit.make_on( - nth_gate=lambda n: cirq.Z if n & 1 else cirq.I, **infra.get_named_qubits(gate.registers) + nth_gate=lambda n: cirq.Z if n & 1 else cirq.I, **infra.get_named_qubits(gate.signature) ) # Note: ApplyGateToLthQubit doesn't support value equality. assert op.qubits == op2.qubits diff --git a/cirq-ft/cirq_ft/algos/arithmetic_gates.py b/cirq-ft/cirq_ft/algos/arithmetic_gates.py index 6054c90709f..1a5f03b6ccf 100644 --- a/cirq-ft/cirq_ft/algos/arithmetic_gates.py +++ b/cirq-ft/cirq_ft/algos/arithmetic_gates.py @@ -137,7 +137,7 @@ class BiQubitsMixer(infra.GateWithRegisters): """Implements the COMPARE2 (Fig. 1) https://static-content.springer.com/esm/art%3A10.1038%2Fs41534-018-0071-5/MediaObjects/41534_2018_71_MOESM1_ESM.pdf This gates mixes the values in a way that preserves the result of comparison. - The registers being compared are 2-qubit registers where + The signature being compared are 2-qubit signature where x = 2*x_msb + x_lsb y = 2*y_msb + y_lsb The Gate mixes the 4 qubits so that sign(x - y) = sign(x_lsb' - y_lsb') where x_lsb' and y_lsb' @@ -147,8 +147,8 @@ class BiQubitsMixer(infra.GateWithRegisters): adjoint: bool = False @cached_property - def registers(self) -> infra.Registers: - return infra.Registers.build(x=2, y=2, ancilla=3) + def signature(self) -> infra.Signature: + return infra.Signature.build(x=2, y=2, ancilla=3) def __repr__(self) -> str: return f'cirq_ft.algos.BiQubitsMixer({self.adjoint})' @@ -221,8 +221,8 @@ class SingleQubitCompare(infra.GateWithRegisters): adjoint: bool = False @cached_property - def registers(self) -> infra.Registers: - return infra.Registers.build(a=1, b=1, less_than=1, greater_than=1) + def signature(self) -> infra.Signature: + return infra.Signature.build(a=1, b=1, less_than=1, greater_than=1) def __repr__(self) -> str: return f'cirq_ft.algos.SingleQubitCompare({self.adjoint})' @@ -437,7 +437,7 @@ def _has_unitary_(self): class ContiguousRegisterGate(cirq.ArithmeticGate): """Applies U|x>|y>|0> -> |x>|y>|x(x-1)/2 + y> - This is useful in the case when $|x>$ and $|y>$ represent two selection registers such that + This is useful in the case when $|x>$ and $|y>$ represent two selection signature such that $y < x$. For example, imagine a classical for-loop over two variables $x$ and $y$: >>> N = 10 @@ -460,8 +460,8 @@ class ContiguousRegisterGate(cirq.ArithmeticGate): Note that both the for-loops iterate over the same ranges and in the same order. The only difference is that the second loop is a "flattened" version of the first one. - Such a flattening of selection registers is useful when we want to load multi dimensional - data to a target register which is indexed on selection registers $x$ and $y$ such that + Such a flattening of selection signature is useful when we want to load multi dimensional + data to a target register which is indexed on selection signature $x$ and $y$ such that $0<= y <= x < N$ and we want to use a `SelectSwapQROM` to laod this data; which gives a sqrt-speedup over a traditional QROM at the cost of using more memory and loading chunks of size `sqrt(N)` in a single iteration. See the reference for more details. diff --git a/cirq-ft/cirq_ft/algos/generic_select_test.py b/cirq-ft/cirq_ft/algos/generic_select_test.py index 255e9ba6b79..8f454857e07 100644 --- a/cirq-ft/cirq_ft/algos/generic_select_test.py +++ b/cirq-ft/cirq_ft/algos/generic_select_test.py @@ -256,7 +256,7 @@ def test_generic_select_consistent_protocols_and_controlled(): # Build GenericSelect gate. gate = cirq_ft.GenericSelect(select_bitsize, num_sites, dps_hamiltonian) - op = gate.on_registers(**infra.get_named_qubits(gate.registers)) + op = gate.on_registers(**infra.get_named_qubits(gate.signature)) cirq.testing.assert_equivalent_repr(gate, setup_code='import cirq\nimport cirq_ft') # Build controlled gate diff --git a/cirq-ft/cirq_ft/algos/hubbard_model.ipynb b/cirq-ft/cirq_ft/algos/hubbard_model.ipynb index bbecd54cf8b..99f17654978 100644 --- a/cirq-ft/cirq_ft/algos/hubbard_model.ipynb +++ b/cirq-ft/cirq_ft/algos/hubbard_model.ipynb @@ -171,7 +171,7 @@ " logN = (2 * (dim - 1).bit_length() + 1)\n", " # 2 * (4 * (N - 1)) : From 2 SelectMajoranaFermion gates.\n", " # 4 * (N/2) : From 1 mulit-controlled ApplyToLthQubit gate on N / 2 targets.\n", - " # 2 * 7 * logN : From 2 CSWAPS on logN qubits corresponding to (p, q) select registers.\n", + " # 2 * 7 * logN : From 2 CSWAPS on logN qubits corresponding to (p, q) select signature.\n", " assert cost.t == 10 * N + 14 * logN - 8\n", " assert cost.rotations == 0\n", " x.append(N)\n", diff --git a/cirq-ft/cirq_ft/algos/hubbard_model.py b/cirq-ft/cirq_ft/algos/hubbard_model.py index dad2a443c9a..8c9a450c5d2 100644 --- a/cirq-ft/cirq_ft/algos/hubbard_model.py +++ b/cirq-ft/cirq_ft/algos/hubbard_model.py @@ -93,7 +93,7 @@ class SelectHubbard(select_and_prepare.SelectOracle): control_val: Optional bit specifying the control value for constructing a controlled version of this gate. Defaults to None, which means no control. - Registers: + Signature: control: A control bit for the entire gate. U: Whether we're applying the single-site part of the potential. V: Whether we're applying the pairwise part of the potential. @@ -139,8 +139,8 @@ def target_registers(self) -> Tuple[infra.Register, ...]: return (infra.Register('target', self.x_dim * self.y_dim * 2),) @cached_property - def registers(self) -> infra.Registers: - return infra.Registers( + def signature(self) -> infra.Signature: + return infra.Signature( [*self.control_registers, *self.selection_registers, *self.target_registers] ) @@ -157,8 +157,12 @@ def decompose_from_registers( yield selected_majorana_fermion.SelectedMajoranaFermionGate( selection_regs=( infra.SelectionRegister('alpha', 1, 2), - infra.SelectionRegister('p_y', self.registers['p_y'].total_bits(), self.y_dim), - infra.SelectionRegister('p_x', self.registers['p_x'].total_bits(), self.x_dim), + infra.SelectionRegister( + 'p_y', self.signature.get_left('p_y').total_bits(), self.y_dim + ), + infra.SelectionRegister( + 'p_x', self.signature.get_left('p_x').total_bits(), self.x_dim + ), ), control_regs=self.control_registers, target_gate=cirq.Y, @@ -170,8 +174,8 @@ def decompose_from_registers( q_selection_regs = ( infra.SelectionRegister('beta', 1, 2), - infra.SelectionRegister('q_y', self.registers['q_y'].total_bits(), self.y_dim), - infra.SelectionRegister('q_x', self.registers['q_x'].total_bits(), self.x_dim), + infra.SelectionRegister('q_y', self.signature.get_left('q_y').total_bits(), self.y_dim), + infra.SelectionRegister('q_x', self.signature.get_left('q_x').total_bits(), self.x_dim), ) yield selected_majorana_fermion.SelectedMajoranaFermionGate( selection_regs=q_selection_regs, control_regs=self.control_registers, target_gate=cirq.X @@ -194,8 +198,12 @@ def decompose_from_registers( yield apply_gate_to_lth_target.ApplyGateToLthQubit( selection_regs=( - infra.SelectionRegister('q_y', self.registers['q_y'].total_bits(), self.y_dim), - infra.SelectionRegister('q_x', self.registers['q_x'].total_bits(), self.x_dim), + infra.SelectionRegister( + 'q_y', self.signature.get_left('q_y').total_bits(), self.y_dim + ), + infra.SelectionRegister( + 'q_x', self.signature.get_left('q_x').total_bits(), self.x_dim + ), ), nth_gate=lambda *_: cirq.Z, control_regs=infra.Register('control', 1 + infra.total_bits(self.control_registers)), @@ -255,7 +263,7 @@ class PrepareHubbard(select_and_prepare.PrepareOracle): mu: coefficient for single body Z term and two-body ZZ terms in the Hubbard model hamiltonian. - Registers: + Signature: control: A control bit for the entire gate. U: Whether we're applying the single-site part of the potential. V: Whether we're applying the pairwise part of the potential. @@ -299,8 +307,8 @@ def junk_registers(self) -> Tuple[infra.Register, ...]: return (infra.Register('temp', 2),) @cached_property - def registers(self) -> infra.Registers: - return infra.Registers([*self.selection_registers, *self.junk_registers]) + def signature(self) -> infra.Signature: + return infra.Signature([*self.selection_registers, *self.junk_registers]) def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] diff --git a/cirq-ft/cirq_ft/algos/hubbard_model_test.py b/cirq-ft/cirq_ft/algos/hubbard_model_test.py index b13f9e6dfd6..21770379e86 100644 --- a/cirq-ft/cirq_ft/algos/hubbard_model_test.py +++ b/cirq-ft/cirq_ft/algos/hubbard_model_test.py @@ -49,7 +49,7 @@ def test_hubbard_model_consistent_protocols(): cirq.testing.assert_equivalent_repr(prepare_gate, setup_code='import cirq_ft') # Build controlled SELECT gate - select_op = select_gate.on_registers(**infra.get_named_qubits(select_gate.registers)) + select_op = select_gate.on_registers(**infra.get_named_qubits(select_gate.signature)) equals_tester = cirq.testing.EqualsTester() equals_tester.add_equality_group( select_gate.controlled(), diff --git a/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle.py b/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle.py index 695a51ba854..7fee906ccf7 100644 --- a/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle.py +++ b/cirq-ft/cirq_ft/algos/mean_estimation/complex_phase_oracle.py @@ -46,8 +46,8 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: return self.encoder.selection_registers @cached_property - def registers(self) -> infra.Registers: - return infra.Registers([*self.control_registers, *self.selection_registers]) + def signature(self) -> infra.Signature: + return infra.Signature([*self.control_registers, *self.selection_registers]) def decompose_from_registers( self, diff --git a/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator.py b/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator.py index 40de332dad1..3804cf9eee6 100644 --- a/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator.py +++ b/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator.py @@ -109,8 +109,8 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: return self.code.encoder.selection_registers @cached_property - def registers(self) -> infra.Registers: - return infra.Registers([*self.control_registers, *self.selection_registers]) + def signature(self) -> infra.Signature: + return infra.Signature([*self.control_registers, *self.selection_registers]) def decompose_from_registers( self, @@ -118,8 +118,8 @@ def decompose_from_registers( context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid], # type:ignore[type-var] ) -> cirq.OP_TREE: - select_reg = {reg.name: quregs[reg.name] for reg in self.select.registers} - reflect_reg = {reg.name: quregs[reg.name] for reg in self.reflect.registers} + select_reg = {reg.name: quregs[reg.name] for reg in self.select.signature} + reflect_reg = {reg.name: quregs[reg.name] for reg in self.reflect.signature} select_op = self.select.on_registers(**select_reg) reflect_op = self.reflect.on_registers(**reflect_reg) for _ in range(self.power): @@ -132,7 +132,7 @@ def decompose_from_registers( def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: wire_symbols = [] if self.cv == () else [["@(0)", "@"][self.cv[0]]] wire_symbols += ['U_ko'] * ( - infra.total_bits(self.registers) - infra.total_bits(self.control_registers) + infra.total_bits(self.signature) - infra.total_bits(self.control_registers) ) if self.power != 1: wire_symbols[-1] = f'U_ko^{self.power}' diff --git a/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator_test.py b/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator_test.py index ef9596dd861..cf649e82d8e 100644 --- a/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator_test.py +++ b/cirq-ft/cirq_ft/algos/mean_estimation/mean_estimation_operator_test.py @@ -117,7 +117,7 @@ def satisfies_theorem_321( assert cirq.is_unitary(u) # Compute the final state vector obtained using the synthesizer `Prep |0>` - prep_op = synthesizer.on_registers(**infra.get_named_qubits(synthesizer.registers)) + prep_op = synthesizer.on_registers(**infra.get_named_qubits(synthesizer.signature)) prep_state = cirq.Circuit(prep_op).final_state_vector() expected_hav = abs(mu) * np.sqrt(1 / (1 + s**2)) @@ -252,7 +252,7 @@ def test_mean_estimation_operator_consistent_protocols(): encoder = BernoulliEncoder(p, (0, y_1), selection_bitsize, target_bitsize) code = CodeForRandomVariable(synthesizer=synthesizer, encoder=encoder) mean_gate = MeanEstimationOperator(code, arctan_bitsize=arctan_bitsize) - op = mean_gate.on_registers(**infra.get_named_qubits(mean_gate.registers)) + op = mean_gate.on_registers(**infra.get_named_qubits(mean_gate.signature)) # Test controlled gate. equals_tester = cirq.testing.EqualsTester() diff --git a/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py b/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py index bb96215b729..e2ef065fb7d 100644 --- a/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py +++ b/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py @@ -35,8 +35,8 @@ def __init__(self, num_targets: int): self._num_targets = num_targets @cached_property - def registers(self) -> infra.Registers: - return infra.Registers.build(control=1, targets=self._num_targets) + def signature(self) -> infra.Signature: + return infra.Signature.build(control=1, targets=self._num_targets) def decompose_from_registers( self, @@ -77,8 +77,8 @@ class MultiControlPauli(infra.GateWithRegisters): target_gate: cirq.Pauli = cirq.X @cached_property - def registers(self) -> infra.Registers: - return infra.Registers.build(controls=len(self.cvs), target=1) + def signature(self) -> infra.Signature: + return infra.Signature.build(controls=len(self.cvs), target=1) def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: NDArray['cirq.Qid'] diff --git a/cirq-ft/cirq_ft/algos/phase_estimation_of_quantum_walk.ipynb b/cirq-ft/cirq_ft/algos/phase_estimation_of_quantum_walk.ipynb index 6121248b9fa..99ca0bed34e 100644 --- a/cirq-ft/cirq_ft/algos/phase_estimation_of_quantum_walk.ipynb +++ b/cirq-ft/cirq_ft/algos/phase_estimation_of_quantum_walk.ipynb @@ -88,8 +88,8 @@ " Fig. 2\n", " \"\"\"\n", " reflect = walk.reflect\n", - " walk_regs = infra.get_named_qubits(walk.registers)\n", - " reflect_regs = {k:v for k, v in walk_regs.items() if k in reflect.registers}\n", + " walk_regs = infra.get_named_qubits(walk.signature)\n", + " reflect_regs = {reg.name: walk_regs[reg.name] for reg in reflect.signature}\n", " \n", " reflect_controlled = reflect.controlled(control_values=[0])\n", " walk_controlled = walk.controlled(control_values=[1])\n", diff --git a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py index 6497e3d65c5..db422d48958 100644 --- a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py +++ b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py @@ -51,15 +51,15 @@ class PrepareUniformSuperposition(infra.GateWithRegisters): ) @cached_property - def registers(self) -> infra.Registers: - return infra.Registers.build(controls=len(self.cv), target=(self.n - 1).bit_length()) + def signature(self) -> infra.Signature: + return infra.Signature.build(controls=len(self.cv), target=(self.n - 1).bit_length()) def __repr__(self) -> str: return f"cirq_ft.PrepareUniformSuperposition({self.n}, cv={self.cv})" def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: control_symbols = ["@" if cv else "@(0)" for cv in self.cv] - target_symbols = ['target'] * self.registers['target'].total_bits() + target_symbols = ['target'] * self.signature.get_left('target').total_bits() target_symbols[0] = f"UNIFORM({self.n})" return cirq.CircuitDiagramInfo(wire_symbols=control_symbols + target_symbols) @@ -75,7 +75,7 @@ def decompose_from_registers( while n > 1 and n % 2 == 0: k += 1 n = n // 2 - l, logL = int(n), self.registers['target'].total_bits() - k + l, logL = int(n), self.signature.get_left('target').total_bits() - k logL_qubits = target[:logL] yield [ diff --git a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition_test.py b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition_test.py index f58cc671e63..9d56a9d2e27 100644 --- a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition_test.py +++ b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition_test.py @@ -52,7 +52,7 @@ def test_prepare_uniform_superposition_t_complexity(n: int): result = cirq_ft.t_complexity(gate) # TODO(#233): Controlled-H is currently counted as a separate rotation, but it can be # implemented using 2 T-gates. - assert result.rotations <= 2 + 2 * infra.total_bits(gate.registers) + assert result.rotations <= 2 + 2 * infra.total_bits(gate.signature) assert result.t <= 12 * (n - 1).bit_length() diff --git a/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array.py b/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array.py index 158ec1a112d..07bf220994e 100644 --- a/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array.py +++ b/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array.py @@ -119,8 +119,8 @@ def interleaved_unitary_target(self) -> Tuple[infra.Register, ...]: pass @cached_property - def registers(self) -> infra.Registers: - return infra.Registers( + def signature(self) -> infra.Signature: + return infra.Signature( [ *self.selection_registers, *self.kappa_load_target, diff --git a/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array_test.py b/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array_test.py index cdd4212bcc3..c1a275db06a 100644 --- a/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array_test.py +++ b/cirq-ft/cirq_ft/algos/programmable_rotation_gate_array_test.py @@ -36,7 +36,7 @@ def interleaved_unitary( @cached_property def interleaved_unitary_target(self) -> Tuple[cirq_ft.Register, ...]: - return tuple(cirq_ft.Registers.build(unrelated_target=1)) + return tuple(cirq_ft.Signature.build(unrelated_target=1)) def construct_custom_prga(*args, **kwargs) -> cirq_ft.ProgrammableRotationGateArrayBase: @@ -74,7 +74,7 @@ def test_programmable_rotation_gate_array(angles, kappa, constructor): for i in range(len(angles) - 1) ] # Get qubits on which rotations + unitaries act. - rotations_and_unitary_registers = cirq_ft.Registers( + rotations_and_unitary_registers = cirq_ft.Signature( [ *programmable_rotation_gate.rotations_target, *programmable_rotation_gate.interleaved_unitary_target, @@ -95,7 +95,10 @@ def rotation_ops(theta: int) -> cirq.OP_TREE: # Set bits in initial_state s.t. selection register stores `selection_integer`. qubit_vals = {x: 0 for x in g.all_qubits} qubit_vals.update( - zip(g.quregs['selection'], iter_bits(selection_integer, g.r['selection'].total_bits())) + zip( + g.quregs['selection'], + iter_bits(selection_integer, g.r.get_left('selection').total_bits()), + ) ) initial_state = [qubit_vals[x] for x in g.all_qubits] # Actual circuit simulation. diff --git a/cirq-ft/cirq_ft/algos/qrom.ipynb b/cirq-ft/cirq_ft/algos/qrom.ipynb index 6ddae7f0537..2f37c079d54 100644 --- a/cirq-ft/cirq_ft/algos/qrom.ipynb +++ b/cirq-ft/cirq_ft/algos/qrom.ipynb @@ -60,13 +60,13 @@ "Gate to load data[l] in the target register when the selection stores an index l.\n", "\n", "In the case of multi-dimensional data[p,q,r,...] we use of multple name\n", - "selection registers [p, q, r, ...] to index and load the data.\n", + "selection signature [p, q, r, ...] to index and load the data.\n", "\n", "#### Parameters\n", " - `data`: List of numpy ndarrays specifying the data to load. If the length of this list is greater than one then we use the same selection indices to load each dataset (for example, to load alt and keep data for state preparation). Each data set is required to have the same shape and to be of integer type.\n", " - `selection_bitsizes`: The number of bits used to represent each selection register corresponding to the size of each dimension of the array. Should be the same length as the shape of each of the datasets.\n", - " - `target_bitsizes`: The number of bits used to represent the data registers. This can be deduced from the maximum element of each of the datasets. Should be of length len(data), i.e. the number of datasets.\n", - " - `num_controls`: The number of control registers.\n" + " - `target_bitsizes`: The number of bits used to represent the data signature. This can be deduced from the maximum element of each of the datasets. Should be of length len(data), i.e. the number of datasets.\n", + " - `num_controls`: The number of control signature.\n" ] }, { diff --git a/cirq-ft/cirq_ft/algos/qrom.py b/cirq-ft/cirq_ft/algos/qrom.py index 9feb90ad125..00a02d710cb 100644 --- a/cirq-ft/cirq_ft/algos/qrom.py +++ b/cirq-ft/cirq_ft/algos/qrom.py @@ -29,8 +29,8 @@ class QROM(unary_iteration_gate.UnaryIterationGate): """Gate to load data[l] in the target register when the selection stores an index l. In the case of multi-dimensional data[p,q,r,...] we use multiple named - selection registers [p, q, r, ...] to index and load the data. Here `p, q, r, ...` - correspond to registers named `selection0`, `selection1`, `selection2`, ... etc. + selection signature [p, q, r, ...] to index and load the data. Here `p, q, r, ...` + correspond to signature named `selection0`, `selection1`, `selection2`, ... etc. When the input data elements contain consecutive entries of identical data elements to load, the QROM also implements the "variable-spaced" QROM optimization described in Ref[2]. @@ -45,9 +45,9 @@ class QROM(unary_iteration_gate.UnaryIterationGate): corresponding to the size of each dimension of the array. Should be the same length as the shape of each of the datasets. target_bitsizes: The number of bits used to represent the data - registers. This can be deduced from the maximum element of each of the + signature. This can be deduced from the maximum element of each of the datasets. Should be of length len(data), i.e. the number of datasets. - num_controls: The number of control registers. + num_controls: The number of control signature. References: [Encoding Electronic Spectra in Quantum Circuits with Linear T Complexity] diff --git a/cirq-ft/cirq_ft/algos/qrom_test.py b/cirq-ft/cirq_ft/algos/qrom_test.py index 514e7f03935..a9a56c53ff4 100644 --- a/cirq-ft/cirq_ft/algos/qrom_test.py +++ b/cirq-ft/cirq_ft/algos/qrom_test.py @@ -36,7 +36,7 @@ def test_qrom_1d(data, num_controls): assert ( len(inverse.all_qubits()) - <= infra.total_bits(g.r) + g.r['selection'].total_bits() + num_controls + <= infra.total_bits(g.r) + g.r.get_left('selection').total_bits() + num_controls ) assert inverse.all_qubits() == decomposed_circuit.all_qubits() @@ -46,7 +46,7 @@ def test_qrom_1d(data, num_controls): qubit_vals.update( zip( g.quregs['selection'], - iter_bits(selection_integer, g.r['selection'].total_bits()), + iter_bits(selection_integer, g.r.get_left('selection').total_bits()), ) ) if num_controls: @@ -75,7 +75,7 @@ def test_qrom_diagram(): d1 = np.array([4, 5, 6]) qrom = cirq_ft.QROM.build(d0, d1) q = cirq.LineQubit.range(cirq.num_qubits(qrom)) - circuit = cirq.Circuit(qrom.on_registers(**infra.split_qubits(qrom.registers, q))) + circuit = cirq.Circuit(qrom.on_registers(**infra.split_qubits(qrom.signature, q))) cirq.testing.assert_has_diagram( circuit, """ diff --git a/cirq-ft/cirq_ft/algos/qubitization_walk_operator.py b/cirq-ft/cirq_ft/algos/qubitization_walk_operator.py index f39964af93f..81219e4fb20 100644 --- a/cirq-ft/cirq_ft/algos/qubitization_walk_operator.py +++ b/cirq-ft/cirq_ft/algos/qubitization_walk_operator.py @@ -73,8 +73,8 @@ def target_registers(self) -> Tuple[infra.Register, ...]: return self.select.target_registers @cached_property - def registers(self) -> infra.Registers: - return infra.Registers( + def signature(self) -> infra.Signature: + return infra.Signature( [*self.control_registers, *self.selection_registers, *self.target_registers] ) @@ -89,10 +89,10 @@ def decompose_from_registers( context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid], # type:ignore[type-var] ) -> cirq.OP_TREE: - select_reg = {reg.name: quregs[reg.name] for reg in self.select.registers} + select_reg = {reg.name: quregs[reg.name] for reg in self.select.signature} select_op = self.select.on_registers(**select_reg) - reflect_reg = {reg.name: quregs[reg.name] for reg in self.reflect.registers} + reflect_reg = {reg.name: quregs[reg.name] for reg in self.reflect.signature} reflect_op = self.reflect.on_registers(**reflect_reg) for _ in range(self.power): yield select_op @@ -103,7 +103,7 @@ def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.Circ self.control_registers ) wire_symbols += ['W'] * ( - infra.total_bits(self.registers) - infra.total_bits(self.control_registers) + infra.total_bits(self.signature) - infra.total_bits(self.control_registers) ) wire_symbols[-1] = f'W^{self.power}' if self.power != 1 else 'W' return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols) diff --git a/cirq-ft/cirq_ft/algos/qubitization_walk_operator_test.py b/cirq-ft/cirq_ft/algos/qubitization_walk_operator_test.py index 9b54501e99c..465bb0fdd71 100644 --- a/cirq-ft/cirq_ft/algos/qubitization_walk_operator_test.py +++ b/cirq-ft/cirq_ft/algos/qubitization_walk_operator_test.py @@ -99,7 +99,7 @@ def test_qubitization_walk_operator_diagrams(): num_sites, eps = 4, 1e-1 walk = get_walk_operator_for_1d_Ising_model(num_sites, eps) # 1. Diagram for $W = SELECT.R_{L}$ - qu_regs = infra.get_named_qubits(walk.registers) + qu_regs = infra.get_named_qubits(walk.signature) walk_op = walk.on_registers(**qu_regs) circuit = cirq.Circuit(cirq.decompose_once(walk_op)) cirq.testing.assert_has_diagram( @@ -217,7 +217,7 @@ def keep(op): def test_qubitization_walk_operator_consistent_protocols_and_controlled(): gate = get_walk_operator_for_1d_Ising_model(4, 1e-1) - op = gate.on_registers(**infra.get_named_qubits(gate.registers)) + op = gate.on_registers(**infra.get_named_qubits(gate.signature)) # Test consistent repr cirq.testing.assert_equivalent_repr( gate, setup_code='import cirq\nimport cirq_ft\nimport numpy as np' diff --git a/cirq-ft/cirq_ft/algos/reflection_using_prepare.py b/cirq-ft/cirq_ft/algos/reflection_using_prepare.py index 980465f524d..97f01b59ed9 100644 --- a/cirq-ft/cirq_ft/algos/reflection_using_prepare.py +++ b/cirq-ft/cirq_ft/algos/reflection_using_prepare.py @@ -65,8 +65,8 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]: return self.prepare_gate.selection_registers @cached_property - def registers(self) -> infra.Registers: - return infra.Registers([*self.control_registers, *self.selection_registers]) + def signature(self) -> infra.Signature: + return infra.Signature([*self.control_registers, *self.selection_registers]) def decompose_from_registers( self, diff --git a/cirq-ft/cirq_ft/algos/reflection_using_prepare_test.py b/cirq-ft/cirq_ft/algos/reflection_using_prepare_test.py index b4a74c56ea7..1986b77f955 100644 --- a/cirq-ft/cirq_ft/algos/reflection_using_prepare_test.py +++ b/cirq-ft/cirq_ft/algos/reflection_using_prepare_test.py @@ -109,7 +109,7 @@ def test_reflection_using_prepare_diagram(): ) # No control gate = cirq_ft.ReflectionUsingPrepare(prepare_gate, control_val=None) - op = gate.on_registers(**infra.get_named_qubits(gate.registers)) + op = gate.on_registers(**infra.get_named_qubits(gate.signature)) circuit = greedily_allocate_ancilla(cirq.Circuit(cirq.decompose_once(op))) cirq.testing.assert_has_diagram( circuit, @@ -139,7 +139,7 @@ def test_reflection_using_prepare_diagram(): # Control on `|1>` state gate = cirq_ft.ReflectionUsingPrepare(prepare_gate, control_val=1) - op = gate.on_registers(**infra.get_named_qubits(gate.registers)) + op = gate.on_registers(**infra.get_named_qubits(gate.signature)) circuit = greedily_allocate_ancilla(cirq.Circuit(cirq.decompose_once(op))) cirq.testing.assert_has_diagram( circuit, @@ -168,7 +168,7 @@ def test_reflection_using_prepare_diagram(): # Control on `|0>` state gate = cirq_ft.ReflectionUsingPrepare(prepare_gate, control_val=0) - op = gate.on_registers(**infra.get_named_qubits(gate.registers)) + op = gate.on_registers(**infra.get_named_qubits(gate.signature)) circuit = greedily_allocate_ancilla(cirq.Circuit(cirq.decompose_once(op))) cirq.testing.assert_has_diagram( circuit, @@ -204,7 +204,7 @@ def test_reflection_using_prepare_consistent_protocols_and_controlled(): ) # No control gate = cirq_ft.ReflectionUsingPrepare(prepare_gate, control_val=None) - op = gate.on_registers(**infra.get_named_qubits(gate.registers)) + op = gate.on_registers(**infra.get_named_qubits(gate.signature)) # Test consistent repr cirq.testing.assert_equivalent_repr( gate, setup_code='import cirq\nimport cirq_ft\nimport numpy as np' diff --git a/cirq-ft/cirq_ft/algos/select_and_prepare.py b/cirq-ft/cirq_ft/algos/select_and_prepare.py index 72958b7fb4f..836d8b62ea0 100644 --- a/cirq-ft/cirq_ft/algos/select_and_prepare.py +++ b/cirq-ft/cirq_ft/algos/select_and_prepare.py @@ -53,8 +53,8 @@ def target_registers(self) -> Tuple[infra.Register, ...]: ... @cached_property - def registers(self) -> infra.Registers: - return infra.Registers( + def signature(self) -> infra.Signature: + return infra.Signature( [*self.control_registers, *self.selection_registers, *self.target_registers] ) @@ -84,5 +84,5 @@ def junk_registers(self) -> Tuple[infra.Register, ...]: return () @cached_property - def registers(self) -> infra.Registers: - return infra.Registers([*self.selection_registers, *self.junk_registers]) + def signature(self) -> infra.Signature: + return infra.Signature([*self.selection_registers, *self.junk_registers]) diff --git a/cirq-ft/cirq_ft/algos/select_swap_qrom.py b/cirq-ft/cirq_ft/algos/select_swap_qrom.py index 248dd1ab4f5..d9b23ee75d3 100644 --- a/cirq-ft/cirq_ft/algos/select_swap_qrom.py +++ b/cirq-ft/cirq_ft/algos/select_swap_qrom.py @@ -62,7 +62,7 @@ class SelectSwapQROM(infra.GateWithRegisters): target register as follows: * Divide the `N` data elements into batches of size `B` (a variable) and - load each batch simultaneously into `B` distinct target registers using the conventional + load each batch simultaneously into `B` distinct target signature using the conventional QROM. This has T-complexity `O(N / B)`. * Use `SwapWithZeroGate` to swap the `i % B`'th target register in batch number `i / B` to load `data[i]` in the 0'th target register. This has T-complexity `O(B * b)`. @@ -96,12 +96,12 @@ def __init__( SelectSwapQROM requires: - Selection register & ancilla of size `logN` for QROM data load. - 1 clean target register of size `b`. - - `B` dirty target registers, each of size `b`. + - `B` dirty target signature, each of size `b`. Similarly, to load `M` such data sequences, `SelectSwapQROM` requires: - Selection register & ancilla of size `logN` for QROM data load. - 1 clean target register of size `sum(target_bitsizes)`. - - `B` dirty target registers, each of size `sum(target_bitsizes)`. + - `B` dirty target signature, each of size `sum(target_bitsizes)`. Args: data: Sequence of integers to load in the target register. If more than one sequence @@ -153,8 +153,8 @@ def target_registers(self) -> Tuple[infra.Register, ...]: ) @cached_property - def registers(self) -> infra.Registers: - return infra.Registers([*self.selection_registers, *self.target_registers]) + def signature(self) -> infra.Signature: + return infra.Signature([*self.selection_registers, *self.target_registers]) @property def data(self) -> Tuple[Tuple[int, ...], ...]: diff --git a/cirq-ft/cirq_ft/algos/select_swap_qrom_test.py b/cirq-ft/cirq_ft/algos/select_swap_qrom_test.py index 85c64425d63..615596ee15e 100644 --- a/cirq-ft/cirq_ft/algos/select_swap_qrom_test.py +++ b/cirq-ft/cirq_ft/algos/select_swap_qrom_test.py @@ -24,7 +24,7 @@ @pytest.mark.parametrize("block_size", [None, 1, 2, 3]) def test_select_swap_qrom(data, block_size): qrom = cirq_ft.SelectSwapQROM(*data, block_size=block_size) - qubit_regs = infra.get_named_qubits(qrom.registers) + qubit_regs = infra.get_named_qubits(qrom.signature) selection = qubit_regs["selection"] selection_q, selection_r = selection[: qrom.selection_q], selection[qrom.selection_q :] targets = [qubit_regs[f"target{i}"] for i in range(len(data))] @@ -78,7 +78,7 @@ def test_qroam_diagram(): blocksize = 2 qrom = cirq_ft.SelectSwapQROM(*data, block_size=blocksize) q = cirq.LineQubit.range(cirq.num_qubits(qrom)) - circuit = cirq.Circuit(qrom.on_registers(**infra.split_qubits(qrom.registers, q))) + circuit = cirq.Circuit(qrom.on_registers(**infra.split_qubits(qrom.signature, q))) cirq.testing.assert_has_diagram( circuit, """ diff --git a/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py b/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py index a97eb752adb..685dc08cc92 100644 --- a/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py +++ b/cirq-ft/cirq_ft/algos/selected_majorana_fermion.py @@ -34,9 +34,9 @@ class SelectedMajoranaFermionGate(unary_iteration_gate.UnaryIterationGate): Args: - selection_regs: Indexing `select` registers of type `SelectionRegister`. It also contains + selection_regs: Indexing `select` signature of type `SelectionRegister`. It also contains information about the iteration length of each selection register. - control_regs: Control registers for constructing a controlled version of the gate. + control_regs: Control signature for constructing a controlled version of the gate. target_gate: Single qubit gate to be applied to the target qubits. References: diff --git a/cirq-ft/cirq_ft/algos/selected_majorana_fermion_test.py b/cirq-ft/cirq_ft/algos/selected_majorana_fermion_test.py index 9367bcc607f..5a50af6a2bb 100644 --- a/cirq-ft/cirq_ft/algos/selected_majorana_fermion_test.py +++ b/cirq-ft/cirq_ft/algos/selected_majorana_fermion_test.py @@ -28,7 +28,7 @@ def test_selected_majorana_fermion_gate(selection_bitsize, target_bitsize, targe target_gate=target_gate, ) g = cirq_ft.testing.GateHelper(gate) - assert len(g.all_qubits) <= infra.total_bits(gate.registers) + selection_bitsize + 1 + assert len(g.all_qubits) <= infra.total_bits(gate.signature) + selection_bitsize + 1 sim = cirq.Simulator(dtype=np.complex128) for n in range(target_bitsize): @@ -67,8 +67,8 @@ def test_selected_majorana_fermion_gate_diagram(): cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize), target_gate=cirq.X, ) - circuit = cirq.Circuit(gate.on_registers(**infra.get_named_qubits(gate.registers))) - qubits = list(q for v in infra.get_named_qubits(gate.registers).values() for q in v) + circuit = cirq.Circuit(gate.on_registers(**infra.get_named_qubits(gate.signature))) + qubits = list(q for v in infra.get_named_qubits(gate.signature).values() for q in v) cirq.testing.assert_has_diagram( circuit, """ @@ -143,8 +143,8 @@ def test_selected_majorana_fermion_gate_make_on(): cirq_ft.SelectionRegister('selection', selection_bitsize, target_bitsize), target_gate=cirq.X, ) - op = gate.on_registers(**infra.get_named_qubits(gate.registers)) + op = gate.on_registers(**infra.get_named_qubits(gate.signature)) op2 = cirq_ft.SelectedMajoranaFermionGate.make_on( - target_gate=cirq.X, **infra.get_named_qubits(gate.registers) + target_gate=cirq.X, **infra.get_named_qubits(gate.signature) ) assert op == op2 diff --git a/cirq-ft/cirq_ft/algos/state_preparation.ipynb b/cirq-ft/cirq_ft/algos/state_preparation.ipynb index ed597e1717b..e9cef3f0f42 100644 --- a/cirq-ft/cirq_ft/algos/state_preparation.ipynb +++ b/cirq-ft/cirq_ft/algos/state_preparation.ipynb @@ -83,7 +83,7 @@ "Registers:\n", " selection: The input/output register $|\\ell\\rangle$ of size lg(L) where the desired\n", " coefficient state is prepared.\n", - " temp: Work space comprised of sub registers:\n", + " temp: Work space comprised of sub signature:\n", " - sigma: A mu-sized register containing uniform probabilities for comparison against\n", " `keep`.\n", " - alt: A lg(L)-sized register of alternate indices\n", @@ -93,8 +93,8 @@ "This gate corresponds to the following operations:\n", " - UNIFORM_L on the selection register\n", " - H^mu on the sigma register\n", - " - QROM addressed by the selection register into the alt and keep registers.\n", - " - LessThanEqualGate comparing the keep and sigma registers.\n", + " - QROM addressed by the selection register into the alt and keep signature.\n", + " - LessThanEqualGate comparing the keep and sigma signature.\n", " - Coherent swap between the selection register and alt register if the comparison\n", " returns True.\n", "\n", diff --git a/cirq-ft/cirq_ft/algos/state_preparation.py b/cirq-ft/cirq_ft/algos/state_preparation.py index aa660b5ebf8..9cb0291efde 100644 --- a/cirq-ft/cirq_ft/algos/state_preparation.py +++ b/cirq-ft/cirq_ft/algos/state_preparation.py @@ -57,10 +57,10 @@ class StatePreparationAliasSampling(select_and_prepare.PrepareOracle): selecting `l` uniformly at random and then returning it with probability `keep[l] / 2**mu`; otherwise returning `alt[l]`. - Registers: + Signature: selection: The input/output register $|\ell\rangle$ of size lg(L) where the desired coefficient state is prepared. - temp: Work space comprised of sub registers: + temp: Work space comprised of sub signature: - sigma: A mu-sized register containing uniform probabilities for comparison against `keep`. - alt: A lg(L)-sized register of alternate indices @@ -70,8 +70,8 @@ class StatePreparationAliasSampling(select_and_prepare.PrepareOracle): This gate corresponds to the following operations: - UNIFORM_L on the selection register - H^mu on the sigma register - - QROM addressed by the selection register into the alt and keep registers. - - LessThanEqualGate comparing the keep and sigma registers. + - QROM addressed by the selection register into the alt and keep signature. + - LessThanEqualGate comparing the keep and sigma signature. - Coherent swap between the selection register and alt register if the comparison returns True. @@ -133,7 +133,7 @@ def selection_bitsize(self) -> int: @cached_property def junk_registers(self) -> Tuple[infra.Register, ...]: return tuple( - infra.Registers.build( + infra.Signature.build( sigma_mu=self.sigma_mu_bitsize, alt=self.alternates_bitsize, keep=self.keep_bitsize, diff --git a/cirq-ft/cirq_ft/algos/swap_network.ipynb b/cirq-ft/cirq_ft/algos/swap_network.ipynb index 22ba4635d50..cf87b7c640f 100644 --- a/cirq-ft/cirq_ft/algos/swap_network.ipynb +++ b/cirq-ft/cirq_ft/algos/swap_network.ipynb @@ -59,7 +59,7 @@ "## `MultiTargetCSwap`\n", "Implements a multi-target controlled swap unitary $CSWAP_n = |0><0| I + |1><1| SWAP_n$.\n", "\n", - "This decomposes into a qubitwise SWAP on the two target registers, and takes 14*n T-gates.\n", + "This decomposes into a qubitwise SWAP on the two target signature, and takes 14*n T-gates.\n", "\n", "#### References\n", "[Trading T-gates for dirty qubits in state preparation and unitary synthesis](https://arxiv.org/abs/1812.00954). Low et. al. 2018. See Appendix B.2.c.\n" diff --git a/cirq-ft/cirq_ft/algos/swap_network.py b/cirq-ft/cirq_ft/algos/swap_network.py index 1dd5ca88879..8a2f9362231 100644 --- a/cirq-ft/cirq_ft/algos/swap_network.py +++ b/cirq-ft/cirq_ft/algos/swap_network.py @@ -26,7 +26,7 @@ class MultiTargetCSwap(infra.GateWithRegisters): """Implements a multi-target controlled swap unitary $CSWAP_n = |0><0| I + |1><1| SWAP_n$. - This decomposes into a qubitwise SWAP on the two target registers, and takes 14*n T-gates. + This decomposes into a qubitwise SWAP on the two target signature, and takes 14*n T-gates. References: [Trading T-gates for dirty qubits in state preparation and unitary synthesis] @@ -44,8 +44,8 @@ def make_on( return cls(bitsize=len(quregs['target_x'])).on_registers(**quregs) @cached_property - def registers(self) -> infra.Registers: - return infra.Registers.build(control=1, target_x=self.bitsize, target_y=self.bitsize) + def signature(self) -> infra.Signature: + return infra.Signature.build(control=1, target_x=self.bitsize, target_y=self.bitsize) def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] @@ -157,8 +157,8 @@ def target_registers(self) -> Tuple[infra.Register, ...]: ) @cached_property - def registers(self) -> infra.Registers: - return infra.Registers([*self.selection_registers, *self.target_registers]) + def signature(self) -> infra.Signature: + return infra.Signature([*self.selection_registers, *self.target_registers]) def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] diff --git a/cirq-ft/cirq_ft/algos/swap_network_test.py b/cirq-ft/cirq_ft/algos/swap_network_test.py index 92cb5865a2a..4d1838672be 100644 --- a/cirq-ft/cirq_ft/algos/swap_network_test.py +++ b/cirq-ft/cirq_ft/algos/swap_network_test.py @@ -66,7 +66,7 @@ def test_swap_with_zero_gate(selection_bitsize, target_bitsize, n_target_registe def test_swap_with_zero_gate_diagram(): gate = cirq_ft.SwapWithZeroGate(3, 2, 4) q = cirq.LineQubit.range(cirq.num_qubits(gate)) - circuit = cirq.Circuit(gate.on_registers(**infra.split_qubits(gate.registers, q))) + circuit = cirq.Circuit(gate.on_registers(**infra.split_qubits(gate.signature, q))) cirq.testing.assert_has_diagram( circuit, """ diff --git a/cirq-ft/cirq_ft/algos/unary_iteration.ipynb b/cirq-ft/cirq_ft/algos/unary_iteration.ipynb index 4eabc65a0af..e2920d7b926 100644 --- a/cirq-ft/cirq_ft/algos/unary_iteration.ipynb +++ b/cirq-ft/cirq_ft/algos/unary_iteration.ipynb @@ -248,7 +248,7 @@ "source": [ "## Quantum Circuit\n", "\n", - "We can translate the boolean logic to reversible, quantum logic. It is instructive to start from the suboptimal total control quantum circuit for comparison purposes. We can build this as in the sympy boolean-logic case by adding controlled X operations to the target registers, with the controls on the selection registers toggled on or off according to the binary representation of the selection index.\n", + "We can translate the boolean logic to reversible, quantum logic. It is instructive to start from the suboptimal total control quantum circuit for comparison purposes. We can build this as in the sympy boolean-logic case by adding controlled X operations to the target signature, with the controls on the selection signature toggled on or off according to the binary representation of the selection index.\n", "\n", "Let us first build a GateWithRegisters object to implement the circuit" ] @@ -262,7 +262,7 @@ "source": [ "import cirq\n", "from cirq._compat import cached_property\n", - "from cirq_ft import Registers, GateWithRegisters\n", + "from cirq_ft import Signature, GateWithRegisters\n", "from cirq_ft.infra.bit_tools import iter_bits\n", "\n", "class TotallyControlledNot(GateWithRegisters):\n", @@ -273,12 +273,12 @@ " self._control_bitsize = control_bitsize\n", "\n", " @cached_property\n", - " def registers(self) -> Registers:\n", - " return Registers(\n", + " def signature(self) -> Signature:\n", + " return Signature(\n", " [\n", - " *Registers.build(control=self._control_bitsize),\n", - " *Registers.build(selection=self._selection_bitsize),\n", - " *Registers.build(target=self._target_bitsize)\n", + " *Signature.build(control=self._control_bitsize),\n", + " *Signature.build(selection=self._selection_bitsize),\n", + " *Signature.build(target=self._target_bitsize)\n", " ]\n", " )\n", "\n", @@ -471,7 +471,7 @@ "metadata": {}, "outputs": [], "source": [ - "from cirq_ft import Register, Registers, SelectionRegister, UnaryIterationGate\n", + "from cirq_ft import Register, SelectionRegister, UnaryIterationGate\n", "from cirq._compat import cached_property\n", "\n", "class ApplyXToLthQubit(UnaryIterationGate):\n", diff --git a/cirq-ft/cirq_ft/algos/unary_iteration_gate.py b/cirq-ft/cirq_ft/algos/unary_iteration_gate.py index d72ab7381ce..74375646609 100644 --- a/cirq-ft/cirq_ft/algos/unary_iteration_gate.py +++ b/cirq-ft/cirq_ft/algos/unary_iteration_gate.py @@ -282,8 +282,8 @@ def target_registers(self) -> Tuple[infra.Register, ...]: pass @cached_property - def registers(self) -> infra.Registers: - return infra.Registers( + def signature(self) -> infra.Signature: + return infra.Signature( [*self.control_registers, *self.selection_registers, *self.target_registers] ) @@ -295,14 +295,14 @@ def extra_registers(self) -> Tuple[infra.Register, ...]: def nth_operation( self, context: cirq.DecompositionContext, control: cirq.Qid, **kwargs ) -> cirq.OP_TREE: - """Apply nth operation on the target registers when selection registers store `n`. + """Apply nth operation on the target signature when selection signature store `n`. The `UnaryIterationGate` class is a mixin that represents a coherent for-loop over - different indices (i.e. selection registers). This method denotes the "body" of the + different indices (i.e. selection signature). This method denotes the "body" of the for-loop, which is executed `self.selection_registers.total_iteration_size` times and each - iteration represents a unique combination of values stored in selection registers. For each + iteration represents a unique combination of values stored in selection signature. For each call, the method should return the operations that should be applied to the target - registers, given the values stored in selection registers. + signature, given the values stored in selection signature. The derived classes should specify the following arguments as `**kwargs`: 1) `control: cirq.Qid`: A qubit which can be used as a control to selectively @@ -345,7 +345,7 @@ def _break_early(self, selection_index_prefix: Tuple[int, ...], l: int, r: int) representing range `[l, r)`. If True, the internal node is considered equivalent to a leaf node and thus, `self.nth_operation` will be called for only integer `l` in the range [l, r). - When the `UnaryIteration` class is constructed using multiple selection registers, i.e. we + When the `UnaryIteration` class is constructed using multiple selection signature, i.e. we wish to perform nested coherent for-loops, a unary iteration segment tree is constructed corresponding to each nested coherent for-loop. For every such unary iteration segment tree, the `_break_early` condition is checked by passing the `selection_index_prefix` tuple. @@ -398,7 +398,7 @@ def unary_iteration_loops( Returns: `cirq.OP_TREE` implementing `num_loops` nested coherent for-loops, with operations returned by `self.nth_operation` applied conditionally to the target register based - on values of selection registers. + on values of selection signature. """ if nested_depth == num_loops: yield self.nth_operation( diff --git a/cirq-ft/cirq_ft/algos/unary_iteration_gate_test.py b/cirq-ft/cirq_ft/algos/unary_iteration_gate_test.py index 0c754fc5980..7450e5e496a 100644 --- a/cirq-ft/cirq_ft/algos/unary_iteration_gate_test.py +++ b/cirq-ft/cirq_ft/algos/unary_iteration_gate_test.py @@ -99,7 +99,7 @@ def selection_registers(self) -> Tuple[cirq_ft.SelectionRegister, ...]: @cached_property def target_registers(self) -> Tuple[cirq_ft.Register, ...]: return tuple( - cirq_ft.Registers.build( + cirq_ft.Signature.build( t1=self._target_shape[0], t2=self._target_shape[1], t3=self._target_shape[2] ) ) @@ -125,7 +125,7 @@ def test_multi_dimensional_unary_iteration_gate(target_shape: Tuple[int, int, in g = cirq_ft.testing.GateHelper(gate, context=cirq.DecompositionContext(greedy_mm)) assert ( len(g.all_qubits) - <= infra.total_bits(gate.registers) + infra.total_bits(gate.selection_registers) - 1 + <= infra.total_bits(gate.signature) + infra.total_bits(gate.selection_registers) - 1 ) max_i, max_j, max_k = target_shape diff --git a/cirq-ft/cirq_ft/infra/__init__.py b/cirq-ft/cirq_ft/infra/__init__.py index 02f503110ca..13ea10ddca0 100644 --- a/cirq-ft/cirq_ft/infra/__init__.py +++ b/cirq-ft/cirq_ft/infra/__init__.py @@ -15,7 +15,8 @@ from cirq_ft.infra.gate_with_registers import ( GateWithRegisters, Register, - Registers, + Signature, + Side, SelectionRegister, total_bits, split_qubits, diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb b/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb index ef72a1e1479..8ccc674942c 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb +++ b/cirq-ft/cirq_ft/infra/gate_with_registers.ipynb @@ -37,9 +37,9 @@ "id": "c0833444", "metadata": {}, "source": [ - "## `Registers`\n", + "## `Signature`\n", "\n", - "`Register` objects have a name, a bitsize and a shape. `Registers` is an ordered collection of `Register` with some helpful methods." + "`Register` objects have a name, a bitsize and a shape. `Signature` is an ordered collection of `Register` with some helpful methods." ] }, { @@ -49,7 +49,7 @@ "metadata": {}, "outputs": [], "source": [ - "from cirq_ft import Register, Registers, infra\n", + "from cirq_ft import Register, Signature, infra\n", "\n", "control_reg = Register(name='control', bitsize=2)\n", "target_reg = Register(name='target', bitsize=3)\n", @@ -63,7 +63,7 @@ "metadata": {}, "outputs": [], "source": [ - "r = Registers([control_reg, target_reg])\n", + "r = Signature([control_reg, target_reg])\n", "r" ] }, @@ -82,7 +82,7 @@ "metadata": {}, "outputs": [], "source": [ - "r == Registers.build(\n", + "r == Signature.build(\n", " control=2,\n", " target=3,\n", ")" @@ -109,8 +109,8 @@ "class MyGate(GateWithRegisters):\n", " \n", " @property\n", - " def registers(self):\n", - " return Registers.build(\n", + " def signature(self):\n", + " return Signature.build(\n", " control=2,\n", " target=3,\n", " )\n", @@ -152,7 +152,7 @@ "id": "2d725646", "metadata": {}, "source": [ - "The `Registers` object can allocate a dictionary of `cirq.NamedQubit` that we can use to turn our `Gate` into an `Operation`. `GateWithRegisters` exposes an `on_registers` method to compliment Cirq's `on` method where we can use names to make sure each qubit is used appropriately." + "The `Signature` object can allocate a dictionary of `cirq.NamedQubit` that we can use to turn our `Gate` into an `Operation`. `GateWithRegisters` exposes an `on_registers` method to compliment Cirq's `on` method where we can use names to make sure each qubit is used appropriately." ] }, { @@ -162,7 +162,7 @@ "metadata": {}, "outputs": [], "source": [ - "r = gate.registers\n", + "r = gate.signature\n", "quregs = infra.get_named_qubits(r)\n", "quregs" ] diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers.py b/cirq-ft/cirq_ft/infra/gate_with_registers.py index 624397ab479..d514e0cd07b 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers.py +++ b/cirq-ft/cirq_ft/infra/gate_with_registers.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import enum import abc import itertools -from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union, overload +from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union, overload, Iterator from numpy.typing import NDArray import attr @@ -22,13 +23,33 @@ import numpy as np +class Side(enum.Flag): + """Denote LEFT, RIGHT, or THRU signature. + + LEFT signature serve as input lines (only) to the Gate. RIGHT signature are output + lines (only) from the Gate. THRU signature are both input and output. + + Traditional unitary operations will have THRU signature that operate on a collection of + qubits which are then made available to following operations. RIGHT and LEFT signature + imply allocation, deallocation, or reshaping of the signature. + """ + + LEFT = enum.auto() + RIGHT = enum.auto() + THRU = LEFT | RIGHT + + @attr.frozen class Register: """A quantum register used to define the input/output API of a `cirq_ft.GateWithRegister` - Args: + Attributes: name: The string name of the register - shape: Shape of the multi-dimensional qubit register. + bitsize: The number of (qu)bits in the register. + shape: A tuple of integer dimensions to declare a multidimensional register. The + total number of bits is the product of entries in this tuple times `bitsize`. + side: Whether this is a left, right, or thru register. See the documentation for `Side` + for more information. """ name: str @@ -36,6 +57,7 @@ class Register: shape: Tuple[int, ...] = attr.field( converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=() ) + side: Side = Side.THRU @bitsize.validator def bitsize_validator(self, attribute, value): @@ -54,11 +76,17 @@ def total_bits(self) -> int: return self.bitsize * int(np.prod(self.shape)) def __repr__(self): - return f'cirq_ft.Register(name="{self.name}", bitsize={self.bitsize}, shape={self.shape})' + return ( + f'cirq_ft.Register(' + f'name="{self.name}", ' + f'bitsize={self.bitsize}, ' + f'shape={self.shape}, ' + f'side=cirq_ft.infra.{self.side})' + ) def total_bits(registers: Iterable[Register]) -> int: - """Sum of `reg.total_bits()` for each register `reg` in input `registers`.""" + """Sum of `reg.total_bits()` for each register `reg` in input `signature`.""" return sum(reg.total_bits() for reg in registers) @@ -99,7 +127,7 @@ def merge_qubits( def get_named_qubits(registers: Iterable[Register]) -> Dict[str, NDArray[cirq.Qid]]: - """Returns a dictionary of appropriately shaped named qubit registers for input `registers`.""" + """Returns a dictionary of appropriately shaped named qubit signature for input `signature`.""" def _qubit_array(reg: Register): qubits = np.empty(reg.shape + (reg.bitsize,), dtype=object) @@ -124,7 +152,7 @@ def _qubits_for_reg(reg: Register): return {reg.name: _qubits_for_reg(reg) for reg in registers} -class Registers: +class Signature: """An ordered collection of `cirq_ft.Register`. Args: @@ -133,15 +161,16 @@ class Registers: def __init__(self, registers: Iterable[Register]): self._registers = tuple(registers) - self._register_dict = {r.name: r for r in self._registers} - if len(self._registers) != len(self._register_dict): + self._lefts = {r.name: r for r in self._registers if r.side & Side.LEFT} + self._rights = {r.name: r for r in self._registers if r.side & Side.RIGHT} + if len(set(self._lefts) | set(self._rights)) != len(self._registers): raise ValueError("Please provide unique register names.") def __repr__(self): - return f'cirq_ft.Registers({self._registers})' + return f'cirq_ft.Signature({self._registers})' @classmethod - def build(cls, **registers: int) -> 'Registers': + def build(cls, **registers: int) -> 'Signature': return cls(Register(name=k, bitsize=v) for k, v in registers.items() if v > 0) @overload @@ -149,27 +178,24 @@ def __getitem__(self, key: int) -> Register: pass @overload - def __getitem__(self, key: str) -> Register: - pass - - @overload - def __getitem__(self, key: slice) -> 'Registers': + def __getitem__(self, key: slice) -> Tuple[Register, ...]: pass def __getitem__(self, key): - if isinstance(key, slice): - return Registers(self._registers[key]) - elif isinstance(key, int): - return self._registers[key] - elif isinstance(key, str): - return self._register_dict[key] - else: - raise IndexError(f"key {key} must be of the type str/int/slice.") - - def __contains__(self, item: str) -> bool: - return item in self._register_dict - - def __iter__(self): + return self._registers[key] + + def get_left(self, name: str) -> Register: + """Get a left register by name.""" + return self._lefts[name] + + def get_right(self, name: str) -> Register: + """Get a right register by name.""" + return self._rights[name] + + def __contains__(self, item: Register) -> bool: + return item in self._registers + + def __iter__(self) -> Iterator[Register]: yield from self._registers def __len__(self) -> int: @@ -233,6 +259,7 @@ class SelectionRegister(Register): shape: Tuple[int, ...] = attr.field( converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=() ) + side: Side = Side.THRU @iteration_length.default def _default_iteration_length(self): @@ -281,8 +308,8 @@ class GateWithRegisters(cirq.Gate, metaclass=abc.ABCMeta): ... bitsize: int ... ... @property - ... def registers(self) -> cirq_ft.Registers: - ... return cirq_ft.Registers.build(ctrl=1, x=self.bitsize, y=self.bitsize) + ... def signature(self) -> cirq_ft.Signature: + ... return cirq_ft.Signature.build(ctrl=1, x=self.bitsize, y=self.bitsize) ... ... def decompose_from_registers(self, context, ctrl, x, y) -> cirq.OP_TREE: ... yield [cirq.CSWAP(*ctrl, qx, qy) for qx, qy in zip(x, y)] @@ -305,11 +332,11 @@ class GateWithRegisters(cirq.Gate, metaclass=abc.ABCMeta): @property @abc.abstractmethod - def registers(self) -> Registers: + def signature(self) -> Signature: ... def _num_qubits_(self) -> int: - return total_bits(self.registers) + return total_bits(self.signature) def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] @@ -319,7 +346,7 @@ def decompose_from_registers( def _decompose_with_context_( self, qubits: Sequence[cirq.Qid], context: Optional[cirq.DecompositionContext] = None ) -> cirq.OP_TREE: - qubit_regs = split_qubits(self.registers, qubits) + qubit_regs = split_qubits(self.signature, qubits) if context is None: context = cirq.DecompositionContext(cirq.ops.SimpleQubitManager()) return self.decompose_from_registers(context=context, **qubit_regs) @@ -330,7 +357,7 @@ def _decompose_(self, qubits: Sequence[cirq.Qid]) -> cirq.OP_TREE: def on_registers( self, **qubit_regs: Union[cirq.Qid, Sequence[cirq.Qid], NDArray[cirq.Qid]] ) -> cirq.Operation: - return self.on(*merge_qubits(self.registers, **qubit_regs)) + return self.on(*merge_qubits(self.signature, **qubit_regs)) def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo: """Default diagram info that uses register names to name the boxes in multi-qubit gates. @@ -338,7 +365,7 @@ def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.Circ Descendants can override this method with more meaningful circuit diagram information. """ wire_symbols = [] - for reg in self.registers: + for reg in self.signature: wire_symbols += [reg.name] * reg.total_bits() wire_symbols[0] = self.__class__.__name__ diff --git a/cirq-ft/cirq_ft/infra/gate_with_registers_test.py b/cirq-ft/cirq_ft/infra/gate_with_registers_test.py index 77e60aacbe8..613d44f870a 100644 --- a/cirq-ft/cirq_ft/infra/gate_with_registers_test.py +++ b/cirq-ft/cirq_ft/infra/gate_with_registers_test.py @@ -30,27 +30,31 @@ def test_register(): def test_registers(): - r1 = cirq_ft.Register("r1", 5) - r2 = cirq_ft.Register("r2", 2) + r1 = cirq_ft.Register("r1", 5, side=cirq_ft.infra.Side.LEFT) + r2 = cirq_ft.Register("r2", 2, side=cirq_ft.infra.Side.RIGHT) r3 = cirq_ft.Register("r3", 1) - regs = cirq_ft.Registers([r1, r2, r3]) + regs = cirq_ft.Signature([r1, r2, r3]) assert len(regs) == 3 cirq.testing.assert_equivalent_repr(regs, setup_code='import cirq_ft') with pytest.raises(ValueError, match="unique"): - _ = cirq_ft.Registers([r1, r1]) + _ = cirq_ft.Signature([r1, r1]) assert regs[0] == r1 assert regs[1] == r2 assert regs[2] == r3 - assert regs[0:1] == cirq_ft.Registers([r1]) - assert regs[0:2] == cirq_ft.Registers([r1, r2]) - assert regs[1:3] == cirq_ft.Registers([r2, r3]) + assert regs[0:1] == tuple([r1]) + assert regs[0:2] == tuple([r1, r2]) + assert regs[1:3] == tuple([r2, r3]) - assert regs["r1"] == r1 - assert regs["r2"] == r2 - assert regs["r3"] == r3 + assert regs.get_left("r1") == r1 + assert regs.get_right("r2") == r2 + assert regs.get_left("r3") == r3 + + assert r1 in regs + assert r2 in regs + assert r3 in regs assert list(regs) == [r1, r2, r3] @@ -85,7 +89,7 @@ def test_registers(): # initial registers. for reg_order in [[r1, r2, r3], [r2, r3, r1]]: flat_named_qubits = [ - q for v in get_named_qubits(cirq_ft.Registers(reg_order)).values() for q in v + q for v in get_named_qubits(cirq_ft.Signature(reg_order)).values() for q in v ] expected_qubits = [q for r in reg_order for q in expected_named_qubits[r.name]] assert flat_named_qubits == expected_qubits @@ -109,43 +113,42 @@ def test_selection_registers_consistent(): with pytest.raises(ValueError, match="should be flat"): _ = cirq_ft.SelectionRegister('a', bitsize=1, shape=(3, 5), iteration_length=5) - selection_reg = cirq_ft.Registers( + selection_reg = cirq_ft.Signature( [ cirq_ft.SelectionRegister('n', bitsize=3, iteration_length=5), cirq_ft.SelectionRegister('m', bitsize=4, iteration_length=12), ] ) assert selection_reg[0] == cirq_ft.SelectionRegister('n', 3, 5) - assert selection_reg['n'] == cirq_ft.SelectionRegister('n', 3, 5) assert selection_reg[1] == cirq_ft.SelectionRegister('m', 4, 12) - assert selection_reg[:1] == cirq_ft.Registers([cirq_ft.SelectionRegister('n', 3, 5)]) + assert selection_reg[:1] == tuple([cirq_ft.SelectionRegister('n', 3, 5)]) def test_registers_getitem_raises(): - g = cirq_ft.Registers.build(a=4, b=3, c=2) - with pytest.raises(IndexError, match="must be of the type"): + g = cirq_ft.Signature.build(a=4, b=3, c=2) + with pytest.raises(TypeError, match="indices must be integers or slices"): _ = g[2.5] - selection_reg = cirq_ft.Registers( + selection_reg = cirq_ft.Signature( [cirq_ft.SelectionRegister('n', bitsize=3, iteration_length=5)] ) - with pytest.raises(IndexError, match='must be of the type'): + with pytest.raises(TypeError, match='indices must be integers or slices'): _ = selection_reg[2.5] def test_registers_build(): - regs1 = cirq_ft.Registers([cirq_ft.Register("r1", 5), cirq_ft.Register("r2", 2)]) - regs2 = cirq_ft.Registers.build(r1=5, r2=2) + regs1 = cirq_ft.Signature([cirq_ft.Register("r1", 5), cirq_ft.Register("r2", 2)]) + regs2 = cirq_ft.Signature.build(r1=5, r2=2) assert regs1 == regs2 class _TestGate(cirq_ft.GateWithRegisters): @property - def registers(self) -> cirq_ft.Registers: + def signature(self) -> cirq_ft.Signature: r1 = cirq_ft.Register("r1", 5) r2 = cirq_ft.Register("r2", 2) r3 = cirq_ft.Register("r3", 1) - regs = cirq_ft.Registers([r1, r2, r3]) + regs = cirq_ft.Signature([r1, r2, r3]) return regs def decompose_from_registers(self, *, context, **quregs) -> cirq.OP_TREE: diff --git a/cirq-ft/cirq_ft/infra/jupyter_tools.py b/cirq-ft/cirq_ft/infra/jupyter_tools.py index a9ae4817ef7..b6b98cc9b14 100644 --- a/cirq-ft/cirq_ft/infra/jupyter_tools.py +++ b/cirq-ft/cirq_ft/infra/jupyter_tools.py @@ -13,7 +13,7 @@ # limitations under the License. from pathlib import Path -from typing import Optional +from typing import Iterable import cirq import cirq.contrib.svg.svg as ccsvg @@ -65,14 +65,14 @@ def _map_func(op: cirq.Operation, _): def svg_circuit( circuit: 'cirq.AbstractCircuit', - registers: Optional[gate_with_registers.Registers] = None, + registers: Iterable[gate_with_registers.Register] = (), include_costs: bool = False, ): """Return an SVG object representing a circuit. Args: circuit: The circuit to draw. - registers: Optional `Registers` object to order the qubits. + registers: Optional `Signature` object to order the qubits. include_costs: If true, each operation is annotated with it's T-complexity cost. Raises: @@ -81,7 +81,7 @@ def svg_circuit( if len(circuit) == 0: raise ValueError("Circuit is empty.") - if registers is not None: + if registers: qubit_order = cirq.QubitOrder.explicit( merge_qubits(registers, **get_named_qubits(registers)), fallback=cirq.QubitOrder.DEFAULT ) diff --git a/cirq-ft/cirq_ft/infra/t_complexity.ipynb b/cirq-ft/cirq_ft/infra/t_complexity.ipynb index 3a4c7c4596a..3697188b98e 100644 --- a/cirq-ft/cirq_ft/infra/t_complexity.ipynb +++ b/cirq-ft/cirq_ft/infra/t_complexity.ipynb @@ -61,7 +61,7 @@ "# And of two qubits\n", "gate = And() # create an And gate\n", "# create an operation\n", - "operation = gate.on_registers(**infra.get_named_qubits(gate.registers))\n", + "operation = gate.on_registers(**infra.get_named_qubits(gate.signature))\n", "# this operation doesn't directly support TComplexity but it's decomposable and its components are simple.\n", "print(t_complexity(operation))" ] @@ -82,7 +82,7 @@ "outputs": [], "source": [ "gate = And() ** -1 # adjoint of And\n", - "operation = gate.on_registers(**infra.get_named_qubits(gate.registers))\n", + "operation = gate.on_registers(**infra.get_named_qubits(gate.signature))\n", "# the deomposition is H, measure, CZ, and Reset\n", "print(t_complexity(operation))" ] @@ -104,7 +104,7 @@ "source": [ "n = 5\n", "gate = And((1, )*n)\n", - "operation = gate.on_registers(**infra.get_named_qubits(gate.registers))\n", + "operation = gate.on_registers(**infra.get_named_qubits(gate.signature))\n", "print(t_complexity(operation))" ] }, @@ -122,7 +122,7 @@ " for n in range(2, n_max + 2):\n", " n_controls.append(n)\n", " gate = And(cv=(1, )*n)\n", - " op = gate.on_registers(**infra.get_named_qubits(gate.registers))\n", + " op = gate.on_registers(**infra.get_named_qubits(gate.signature))\n", " c = t_complexity(op)\n", " t_count.append(c.t)\n", " return n_controls, t_count" diff --git a/cirq-ft/cirq_ft/infra/t_complexity_protocol_test.py b/cirq-ft/cirq_ft/infra/t_complexity_protocol_test.py index f28f3fc5e6a..30ecb633d72 100644 --- a/cirq-ft/cirq_ft/infra/t_complexity_protocol_test.py +++ b/cirq-ft/cirq_ft/infra/t_complexity_protocol_test.py @@ -30,8 +30,8 @@ class DoesNotSupportTComplexity: class SupportsTComplexityGateWithRegisters(cirq_ft.GateWithRegisters): @property - def registers(self) -> cirq_ft.Registers: - return cirq_ft.Registers.build(s=1, t=2) + def signature(self) -> cirq_ft.Signature: + return cirq_ft.Signature.build(s=1, t=2) def _t_complexity_(self) -> cirq_ft.TComplexity: return cirq_ft.TComplexity(t=1, clifford=2) @@ -109,11 +109,11 @@ def test_operations(): assert cirq_ft.t_complexity(cirq.T(q)) == cirq_ft.TComplexity(t=1) gate = cirq_ft.And() - op = gate.on_registers(**infra.get_named_qubits(gate.registers)) + op = gate.on_registers(**infra.get_named_qubits(gate.signature)) assert cirq_ft.t_complexity(op) == cirq_ft.TComplexity(t=4, clifford=9) gate = cirq_ft.And() ** -1 - op = gate.on_registers(**infra.get_named_qubits(gate.registers)) + op = gate.on_registers(**infra.get_named_qubits(gate.signature)) assert cirq_ft.t_complexity(op) == cirq_ft.TComplexity(clifford=4) diff --git a/cirq-ft/cirq_ft/infra/testing.py b/cirq-ft/cirq_ft/infra/testing.py index 31802d5300d..088b46dab45 100644 --- a/cirq-ft/cirq_ft/infra/testing.py +++ b/cirq-ft/cirq_ft/infra/testing.py @@ -37,13 +37,13 @@ class GateHelper: context: cirq.DecompositionContext = cirq.DecompositionContext(cirq.ops.SimpleQubitManager()) @cached_property - def r(self) -> gate_with_registers.Registers: - """The Registers system for the gate.""" - return self.gate.registers + def r(self) -> gate_with_registers.Signature: + """The Signature system for the gate.""" + return self.gate.signature @cached_property def quregs(self) -> Dict[str, NDArray[cirq.Qid]]: # type: ignore[type-var] - """A dictionary of named qubits appropriate for the registers for the gate.""" + """A dictionary of named qubits appropriate for the signature for the gate.""" return get_named_qubits(self.r) @cached_property diff --git a/cirq-ft/cirq_ft/infra/testing_test.py b/cirq-ft/cirq_ft/infra/testing_test.py index ad3c8a83674..15f0b0a5f06 100644 --- a/cirq-ft/cirq_ft/infra/testing_test.py +++ b/cirq-ft/cirq_ft/infra/testing_test.py @@ -34,7 +34,7 @@ def test_assert_circuit_inp_out_cirqsim(): def test_gate_helper(): g = cirq_ft.testing.GateHelper(cirq_ft.And(cv=(1, 0, 1, 0))) assert g.gate == cirq_ft.And(cv=(1, 0, 1, 0)) - assert g.r == cirq_ft.Registers.build(control=4, ancilla=2, target=1) + assert g.r == cirq_ft.Signature.build(control=4, ancilla=2, target=1) expected_quregs = { 'control': cirq.NamedQubit.range(4, prefix='control'), 'ancilla': cirq.NamedQubit.range(2, prefix='ancilla'), From 1948e732b197fa0850d3695f0d4c0c10c97502c1 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Mon, 25 Sep 2023 17:28:22 -0700 Subject: [PATCH 18/19] Change signature of `cirq_ft.And` gate to use directional registers like Qualtran (#6302) * Change signature of gate to use directional registers like Qualtran * Remove debug prints --- cirq-ft/cirq_ft/algos/and_gate.py | 17 +- cirq-ft/cirq_ft/algos/and_gate_test.py | 145 +++++++++--------- cirq-ft/cirq_ft/algos/hubbard_model.py | 4 +- .../algos/multi_control_multi_target_pauli.py | 8 +- .../algos/prepare_uniform_superposition.py | 4 +- cirq-ft/cirq_ft/algos/qrom.py | 4 +- cirq-ft/cirq_ft/algos/unary_iteration_gate.py | 4 +- cirq-ft/cirq_ft/infra/jupyter_tools_test.py | 2 +- cirq-ft/cirq_ft/infra/testing_test.py | 12 +- 9 files changed, 112 insertions(+), 88 deletions(-) diff --git a/cirq-ft/cirq_ft/algos/and_gate.py b/cirq-ft/cirq_ft/algos/and_gate.py index 1d54d1a66b4..1da187793cd 100644 --- a/cirq-ft/cirq_ft/algos/and_gate.py +++ b/cirq-ft/cirq_ft/algos/and_gate.py @@ -61,7 +61,16 @@ def _validate_cv(self, attribute, value): @cached_property def signature(self) -> infra.Signature: - return infra.Signature.build(control=len(self.cv), ancilla=len(self.cv) - 2, target=1) + one_side = infra.Side.RIGHT if not self.adjoint else infra.Side.LEFT + n_cv = len(self.cv) + junk_reg = [infra.Register('junk', 1, shape=n_cv - 2, side=one_side)] if n_cv > 2 else [] + return infra.Signature( + [ + infra.Register('ctrl', 1, shape=n_cv), + *junk_reg, + infra.Register('target', 1, side=one_side), + ] + ) def __pow__(self, power: int) -> "And": if power == 1: @@ -142,9 +151,9 @@ def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: control, ancilla, target = ( - quregs['control'], - quregs.get('ancilla', np.array([])), - quregs['target'], + quregs['ctrl'].flatten(), + quregs.get('junk', np.array([])).flatten(), + quregs['target'].flatten(), ) if len(self.cv) == 2: yield self._decompose_single_and( diff --git a/cirq-ft/cirq_ft/algos/and_gate_test.py b/cirq-ft/cirq_ft/algos/and_gate_test.py index f79962cfdfa..0ab39868d42 100644 --- a/cirq-ft/cirq_ft/algos/and_gate_test.py +++ b/cirq-ft/cirq_ft/algos/and_gate_test.py @@ -46,7 +46,7 @@ def random_cv(n: int) -> List[int]: def test_multi_controlled_and_gate(cv: List[int]): gate = cirq_ft.And(cv) r = gate.signature - assert r.get_left('ancilla').total_bits() == r.get_left('control').total_bits() - 2 + assert r.get_right('junk').total_bits() == r.get_left('ctrl').total_bits() - 2 quregs = infra.get_named_qubits(r) and_op = gate.on_registers(**quregs) circuit = cirq.Circuit(and_op) @@ -55,7 +55,7 @@ def test_multi_controlled_and_gate(cv: List[int]): qubit_order = infra.merge_qubits(gate.signature, **quregs) for input_control in input_controls: - initial_state = input_control + [0] * (r.get_left('ancilla').total_bits() + 1) + initial_state = input_control + [0] * (r.get_right('junk').total_bits() + 1) result = cirq.Simulator(dtype=np.complex128).simulate( circuit, initial_state=initial_state, qubit_order=qubit_order ) @@ -80,64 +80,67 @@ def test_and_gate_diagram(): gate = cirq_ft.And((1, 0, 1, 0, 1, 0)) qubit_regs = infra.get_named_qubits(gate.signature) op = gate.on_registers(**qubit_regs) - # Qubit order should be alternating (control, ancilla) pairs. - c_and_a = sum(zip(qubit_regs["control"][1:], qubit_regs["ancilla"]), ()) + ( - qubit_regs["control"][-1], + ctrl, junk, target = ( + qubit_regs["ctrl"].flatten(), + qubit_regs["junk"].flatten(), + qubit_regs['target'].flatten(), ) - qubit_order = np.concatenate([qubit_regs["control"][0:1], c_and_a, qubit_regs["target"]]) + # Qubit order should be alternating (control, ancilla) pairs. + c_and_a = sum(zip(ctrl[1:], junk), ()) + (ctrl[-1],) + qubit_order = np.concatenate([ctrl[0:1], c_and_a, target]) # Test diagrams. cirq.testing.assert_has_diagram( cirq.Circuit(op), """ -control0: ───@───── - │ -control1: ───(0)─── - │ -ancilla0: ───Anc─── - │ -control2: ───@───── - │ -ancilla1: ───Anc─── - │ -control3: ───(0)─── - │ -ancilla2: ───Anc─── - │ -control4: ───@───── - │ -ancilla3: ───Anc─── - │ -control5: ───(0)─── - │ -target: ─────And─── +ctrl[0]: ───@───── + │ +ctrl[1]: ───(0)─── + │ +junk[0]: ───Anc─── + │ +ctrl[2]: ───@───── + │ +junk[1]: ───Anc─── + │ +ctrl[3]: ───(0)─── + │ +junk[2]: ───Anc─── + │ +ctrl[4]: ───@───── + │ +junk[3]: ───Anc─── + │ +ctrl[5]: ───(0)─── + │ +target: ────And─── """, qubit_order=qubit_order, ) cirq.testing.assert_has_diagram( cirq.Circuit(op**-1), """ -control0: ───@────── - │ -control1: ───(0)──── - │ -ancilla0: ───Anc──── - │ -control2: ───@────── - │ -ancilla1: ───Anc──── - │ -control3: ───(0)──── - │ -ancilla2: ───Anc──── - │ -control4: ───@────── - │ -ancilla3: ───Anc──── - │ -control5: ───(0)──── - │ -target: ─────And†─── - """, +ctrl[0]: ───@────── + │ +ctrl[1]: ───(0)──── + │ +junk[0]: ───Anc──── + │ +ctrl[2]: ───@────── + │ +junk[1]: ───Anc──── + │ +ctrl[3]: ───(0)──── + │ +junk[2]: ───Anc──── + │ +ctrl[4]: ───@────── + │ +junk[3]: ───Anc──── + │ +ctrl[5]: ───(0)──── + │ +target: ────And†─── +""", qubit_order=qubit_order, ) # Test diagram of decomposed 3-qubit and ladder. @@ -147,28 +150,28 @@ def test_and_gate_diagram(): cirq.testing.assert_has_diagram( decomposed_circuit, """ -control0: ───@─────────────────────────────────────────────────────────@────── - │ │ -control1: ───(0)───────────────────────────────────────────────────────(0)──── - │ │ -ancilla0: ───And───@────────────────────────────────────────────@──────And†─── - │ │ -control2: ─────────@────────────────────────────────────────────@───────────── - │ │ -ancilla1: ─────────And───@───────────────────────────────@──────And†────────── - │ │ -control3: ───────────────(0)─────────────────────────────(0)────────────────── - │ │ -ancilla2: ───────────────And───@──────────────────@──────And†───────────────── - │ │ -control4: ─────────────────────@──────────────────@─────────────────────────── - │ │ -ancilla3: ─────────────────────And───@─────@──────And†──────────────────────── - │ │ -control5: ───────────────────────────(0)───(0)──────────────────────────────── - │ │ -target: ─────────────────────────────And───And†─────────────────────────────── - """, +ctrl[0]: ───@─────────────────────────────────────────────────────────@────── + │ │ +ctrl[1]: ───(0)───────────────────────────────────────────────────────(0)──── + │ │ +junk[0]: ───And───@────────────────────────────────────────────@──────And†─── + │ │ +ctrl[2]: ─────────@────────────────────────────────────────────@───────────── + │ │ +junk[1]: ─────────And───@───────────────────────────────@──────And†────────── + │ │ +ctrl[3]: ───────────────(0)─────────────────────────────(0)────────────────── + │ │ +junk[2]: ───────────────And───@──────────────────@──────And†───────────────── + │ │ +ctrl[4]: ─────────────────────@──────────────────@─────────────────────────── + │ │ +junk[3]: ─────────────────────And───@─────@──────And†──────────────────────── + │ │ +ctrl[5]: ───────────────────────────(0)───(0)──────────────────────────────── + │ │ +target: ────────────────────────────And───And†─────────────────────────────── +""", qubit_order=qubit_order, ) diff --git a/cirq-ft/cirq_ft/algos/hubbard_model.py b/cirq-ft/cirq_ft/algos/hubbard_model.py index 8c9a450c5d2..c182dc89a0b 100644 --- a/cirq-ft/cirq_ft/algos/hubbard_model.py +++ b/cirq-ft/cirq_ft/algos/hubbard_model.py @@ -335,13 +335,13 @@ def decompose_from_registers( and_target = context.qubit_manager.qalloc(1) and_anc = context.qubit_manager.qalloc(1) yield and_gate.And(cv=(0, 0, 1)).on_registers( - control=[*U, *V, temp[-1]], ancilla=and_anc, target=and_target + ctrl=np.array([U, V, temp[-1:]]), junk=np.array([and_anc]), target=and_target ) yield swap_network.MultiTargetCSwap.make_on( control=and_target, target_x=[*p_x, *p_y, *alpha], target_y=[*q_x, *q_y, *beta] ) yield and_gate.And(cv=(0, 0, 1), adjoint=True).on_registers( - control=[*U, *V, temp[-1]], ancilla=and_anc, target=and_target + ctrl=np.array([U, V, temp[-1:]]), junk=np.array([and_anc]), target=and_target ) context.qubit_manager.qfree([*and_anc, *and_target]) diff --git a/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py b/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py index e2ef065fb7d..f098f0799c1 100644 --- a/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py +++ b/cirq-ft/cirq_ft/algos/multi_control_multi_target_pauli.py @@ -85,15 +85,15 @@ def decompose_from_registers( ) -> cirq.OP_TREE: controls, target = quregs['controls'], quregs['target'] qm = context.qubit_manager - and_ancilla, and_target = qm.qalloc(len(self.cvs) - 2), qm.qalloc(1) + and_ancilla, and_target = np.array(qm.qalloc(len(self.cvs) - 2)), qm.qalloc(1) yield and_gate.And(self.cvs).on_registers( - control=controls, ancilla=and_ancilla, target=and_target + ctrl=controls[:, np.newaxis], junk=and_ancilla[:, np.newaxis], target=and_target ) yield self.target_gate.on(*target).controlled_by(*and_target) yield and_gate.And(self.cvs, adjoint=True).on_registers( - control=controls, ancilla=and_ancilla, target=and_target + ctrl=controls[:, np.newaxis], junk=and_ancilla[:, np.newaxis], target=and_target ) - qm.qfree(and_ancilla + and_target) + qm.qfree([*and_ancilla, *and_target]) def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo: wire_symbols = ["@" if b else "@(0)" for b in self.cvs] diff --git a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py index db422d48958..ca7acb297d5 100644 --- a/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py +++ b/cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py @@ -94,7 +94,9 @@ def decompose_from_registers( and_ancilla = context.qubit_manager.qalloc(len(self.cv) + logL - 2) and_op = and_gate.And((0,) * logL + self.cv).on_registers( - control=[*logL_qubits, *controls], ancilla=and_ancilla, target=ancilla + ctrl=np.asarray([*logL_qubits, *controls])[:, np.newaxis], + junk=np.asarray(and_ancilla)[:, np.newaxis], + target=ancilla, ) yield and_op yield cirq.Rz(rads=theta)(*ancilla) diff --git a/cirq-ft/cirq_ft/algos/qrom.py b/cirq-ft/cirq_ft/algos/qrom.py index 00a02d710cb..ab150b585f5 100644 --- a/cirq-ft/cirq_ft/algos/qrom.py +++ b/cirq-ft/cirq_ft/algos/qrom.py @@ -150,7 +150,9 @@ def decompose_zero_selection( and_ancilla = context.qubit_manager.qalloc(len(controls) - 2) and_target = context.qubit_manager.qalloc(1)[0] multi_controlled_and = and_gate.And((1,) * len(controls)).on_registers( - control=controls, ancilla=and_ancilla, target=and_target + ctrl=np.array(controls)[:, np.newaxis], + junk=np.array(and_ancilla)[:, np.newaxis], + target=and_target, ) yield multi_controlled_and yield self._load_nth_data(zero_indx, lambda q: cirq.CNOT(and_target, q), **target_regs) diff --git a/cirq-ft/cirq_ft/algos/unary_iteration_gate.py b/cirq-ft/cirq_ft/algos/unary_iteration_gate.py index 74375646609..245f57f43fe 100644 --- a/cirq-ft/cirq_ft/algos/unary_iteration_gate.py +++ b/cirq-ft/cirq_ft/algos/unary_iteration_gate.py @@ -153,7 +153,9 @@ def _unary_iteration_multi_controls( and_ancilla = ancilla[: num_controls - 2] and_target = ancilla[num_controls - 2] multi_controlled_and = and_gate.And((1,) * len(controls)).on_registers( - control=np.array(controls), ancilla=np.array(and_ancilla), target=and_target + ctrl=np.array(controls).reshape(len(controls), 1), + junk=np.array(and_ancilla).reshape(len(and_ancilla), 1), + target=and_target, ) ops.append(multi_controlled_and) yield from _unary_iteration_single_control( diff --git a/cirq-ft/cirq_ft/infra/jupyter_tools_test.py b/cirq-ft/cirq_ft/infra/jupyter_tools_test.py index d4ecf100d7f..0fcfe630106 100644 --- a/cirq-ft/cirq_ft/infra/jupyter_tools_test.py +++ b/cirq-ft/cirq_ft/infra/jupyter_tools_test.py @@ -27,7 +27,7 @@ def test_svg_circuit(): svg_str = svg.data # check that the order is respected in the svg data. - assert svg_str.find('control') < svg_str.find('ancilla') < svg_str.find('target') + assert svg_str.find('ctrl') < svg_str.find('junk') < svg_str.find('target') # Check svg_circuit raises. with pytest.raises(ValueError): diff --git a/cirq-ft/cirq_ft/infra/testing_test.py b/cirq-ft/cirq_ft/infra/testing_test.py index 15f0b0a5f06..6cb0fbf391f 100644 --- a/cirq-ft/cirq_ft/infra/testing_test.py +++ b/cirq-ft/cirq_ft/infra/testing_test.py @@ -34,10 +34,16 @@ def test_assert_circuit_inp_out_cirqsim(): def test_gate_helper(): g = cirq_ft.testing.GateHelper(cirq_ft.And(cv=(1, 0, 1, 0))) assert g.gate == cirq_ft.And(cv=(1, 0, 1, 0)) - assert g.r == cirq_ft.Signature.build(control=4, ancilla=2, target=1) + assert g.r == cirq_ft.Signature( + [ + cirq_ft.Register('ctrl', bitsize=1, shape=4), + cirq_ft.Register('junk', bitsize=1, shape=2, side=cirq_ft.infra.Side.RIGHT), + cirq_ft.Register('target', bitsize=1, side=cirq_ft.infra.Side.RIGHT), + ] + ) expected_quregs = { - 'control': cirq.NamedQubit.range(4, prefix='control'), - 'ancilla': cirq.NamedQubit.range(2, prefix='ancilla'), + 'ctrl': np.array([[cirq.q(f'ctrl[{i}]')] for i in range(4)]), + 'junk': np.array([[cirq.q(f'junk[{i}]')] for i in range(2)]), 'target': [cirq.NamedQubit('target')], } for key in expected_quregs: From fd18da5738d8e5d6437f864b1c574de0590fc590 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Emilio=20Pel=C3=A1ez?= <63567458+epelaaez@users.noreply.github.com> Date: Tue, 26 Sep 2023 12:49:07 -0500 Subject: [PATCH 19/19] Make `OpIdentifier` serializable for all inputs (#6295) --- cirq-core/cirq/_compat.py | 3 + cirq-core/cirq/devices/noise_utils.py | 12 ++-- cirq-core/cirq/devices/noise_utils_test.py | 7 +++ .../json_test_data/OpIdentifier.json | 55 +++++++++++++++---- .../json_test_data/OpIdentifier.repr | 8 ++- 5 files changed, 69 insertions(+), 16 deletions(-) diff --git a/cirq-core/cirq/_compat.py b/cirq-core/cirq/_compat.py index 566ec1c2a34..7b0923701d5 100644 --- a/cirq-core/cirq/_compat.py +++ b/cirq-core/cirq/_compat.py @@ -191,6 +191,9 @@ def _print(self, expr, **kwargs): if isinstance(value, Dict): return '{' + ','.join(f"{proper_repr(k)}: {proper_repr(v)}" for k, v in value.items()) + '}' + if hasattr(value, "__qualname__"): + return f"{value.__module__}.{value.__qualname__}" + return repr(value) diff --git a/cirq-core/cirq/devices/noise_utils.py b/cirq-core/cirq/devices/noise_utils.py index 4086ba5c741..7fb30a040ba 100644 --- a/cirq-core/cirq/devices/noise_utils.py +++ b/cirq-core/cirq/devices/noise_utils.py @@ -16,6 +16,7 @@ import numpy as np from cirq import ops, protocols, value +from cirq._compat import proper_repr if TYPE_CHECKING: import cirq @@ -78,20 +79,21 @@ def __str__(self): return f'{self.gate_type}{self.qubits}' def __repr__(self) -> str: - fullname = f'{self.gate_type.__module__}.{self.gate_type.__qualname__}' qubits = ', '.join(map(repr, self.qubits)) - return f'cirq.devices.noise_utils.OpIdentifier({fullname}, {qubits})' + return f'cirq.devices.noise_utils.OpIdentifier({proper_repr(self.gate_type)}, {qubits})' def _value_equality_values_(self) -> Any: return (self.gate_type, self.qubits) def _json_dict_(self) -> Dict[str, Any]: - gate_json = protocols.json_cirq_type(self._gate_type) - return {'gate_type': gate_json, 'qubits': self._qubits} + if hasattr(self.gate_type, '__name__'): + return {'gate_type': protocols.json_cirq_type(self._gate_type), 'qubits': self._qubits} + return {'gate_type': self._gate_type, 'qubits': self._qubits} @classmethod def _from_json_dict_(cls, gate_type, qubits, **kwargs) -> 'OpIdentifier': - gate_type = protocols.cirq_type_from_json(gate_type) + if isinstance(gate_type, str): + gate_type = protocols.cirq_type_from_json(gate_type) return cls(gate_type, *qubits) diff --git a/cirq-core/cirq/devices/noise_utils_test.py b/cirq-core/cirq/devices/noise_utils_test.py index 412ff1fa192..df9833424ec 100644 --- a/cirq-core/cirq/devices/noise_utils_test.py +++ b/cirq-core/cirq/devices/noise_utils_test.py @@ -62,6 +62,13 @@ def test_op_id_swap(): assert cirq.CZ(q1, q0) in swap_id +def test_op_id_instance(): + q0 = cirq.LineQubit.range(1)[0] + gate = cirq.SingleQubitCliffordGate.from_xz_map((cirq.X, False), (cirq.Z, False)) + op_id = OpIdentifier(gate, q0) + cirq.testing.assert_equivalent_repr(op_id) + + @pytest.mark.parametrize( 'decay_constant,num_qubits,expected_output', [(0.01, 1, 1 - (0.99 * 1 / 2)), (0.05, 2, 1 - (0.95 * 3 / 4))], diff --git a/cirq-core/cirq/protocols/json_test_data/OpIdentifier.json b/cirq-core/cirq/protocols/json_test_data/OpIdentifier.json index d33b909367d..e9e3f35d109 100644 --- a/cirq-core/cirq/protocols/json_test_data/OpIdentifier.json +++ b/cirq-core/cirq/protocols/json_test_data/OpIdentifier.json @@ -1,10 +1,45 @@ -{ - "cirq_type": "OpIdentifier", - "gate_type": "XPowGate", - "qubits": [ - { - "cirq_type": "LineQubit", - "x": 1 - } - ] -} \ No newline at end of file +[ + { + "cirq_type": "OpIdentifier", + "gate_type": "XPowGate", + "qubits": [ + { + "cirq_type": "LineQubit", + "x": 1 + } + ] + }, + { + "cirq_type": "OpIdentifier", + "gate_type": { + "cirq_type": "CliffordGate", + "n": 1, + "rs": [ + false, + false + ], + "xs": [ + [ + true + ], + [ + false + ] + ], + "zs": [ + [ + false + ], + [ + true + ] + ] + }, + "qubits": [ + { + "cirq_type": "LineQubit", + "x": 0 + } + ] + } +] \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/OpIdentifier.repr b/cirq-core/cirq/protocols/json_test_data/OpIdentifier.repr index 6b991bb0b2c..010a806a9e2 100644 --- a/cirq-core/cirq/protocols/json_test_data/OpIdentifier.repr +++ b/cirq-core/cirq/protocols/json_test_data/OpIdentifier.repr @@ -1,4 +1,10 @@ +[ cirq.devices.noise_utils.OpIdentifier( cirq.ops.common_gates.XPowGate, cirq.LineQubit(1) -) \ No newline at end of file +), +cirq.devices.noise_utils.OpIdentifier( + cirq.CliffordGate.from_clifford_tableau(cirq.CliffordTableau(1,rs=np.array([False, False],dtype=np.dtype('bool')), xs=np.array([[True], [False]], dtype=np.dtype('bool')),zs=np.array([[False], [True]], dtype=np.dtype('bool')), initial_state=0)), + cirq.LineQubit(0) +) +] \ No newline at end of file