diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index c58687e217f..534e45e88ea 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -90,6 +90,10 @@ class CircuitOperation(ops.Operation): targets for unbound `ClassicallyControlledOperation` keys. This field is not intended to be set or changed manually, and should be empty in circuits that aren't in the middle of decomposition. + use_repetition_ids: When True, any measurement key in the subcircuit + will have its path prepended with the repetition id for each + repetition. When False, this will not happen and the measurement + key will be repeated. """ _hash: Optional[int] = dataclasses.field(default=None, init=False) @@ -108,6 +112,7 @@ class CircuitOperation(ops.Operation): repetition_ids: Optional[List[str]] = dataclasses.field(default=None) parent_path: Tuple[str, ...] = dataclasses.field(default_factory=tuple) extern_keys: FrozenSet['cirq.MeasurementKey'] = dataclasses.field(default_factory=frozenset) + use_repetition_ids: bool = True def __post_init__(self): if not isinstance(self.circuit, circuits.FrozenCircuit): @@ -168,6 +173,7 @@ def __eq__(self, other) -> bool: and self.repetitions == other.repetitions and self.repetition_ids == other.repetition_ids and self.parent_path == other.parent_path + and self.use_repetition_ids == other.use_repetition_ids ) # Methods for getting post-mapping properties of the contained circuit. @@ -190,7 +196,7 @@ def _is_measurement_(self) -> bool: def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: if self._cached_measurement_key_objs is None: circuit_keys = protocols.measurement_key_objs(self.circuit) - if self.repetition_ids is not None: + if self.repetition_ids is not None and self.use_repetition_ids: circuit_keys = { key.with_key_path_prefix(repetition_id) for repetition_id in self.repetition_ids @@ -251,7 +257,7 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit': if self.param_resolver: circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False) if self.repetition_ids: - if not protocols.is_measurement(circuit): + if not self.use_repetition_ids or not protocols.is_measurement(circuit): circuit = circuit * abs(self.repetitions) else: circuit = circuits.Circuit( @@ -295,6 +301,8 @@ def __repr__(self): if self.repetition_ids != self._default_repetition_ids(): # Default repetition_ids need not be specified. args += f'repetition_ids={proper_repr(self.repetition_ids)},\n' + if not self.use_repetition_ids: + args += 'use_repetition_ids=False,\n' indented_args = args.replace('\n', '\n ') return f'cirq.CircuitOperation({indented_args[:-4]})' @@ -325,6 +333,8 @@ def dict_str(d: Dict) -> str: elif self.repetitions != 1: # Only add loops if we haven't added repetition_ids. args.append(f'loops={self.repetitions}') + if not self.use_repetition_ids: + args.append('no_rep_ids') if not args: return circuit_msg return f'{circuit_msg}({", ".join(args)})' @@ -343,13 +353,14 @@ def __hash__(self): self.param_resolver, self.parent_path, tuple([] if self.repetition_ids is None else self.repetition_ids), + self.use_repetition_ids, ) ), ) return self._hash def _json_dict_(self): - return { + resp = { 'circuit': self.circuit, 'repetitions': self.repetitions, # JSON requires mappings to have keys of basic types. @@ -360,6 +371,9 @@ def _json_dict_(self): 'repetition_ids': self.repetition_ids, 'parent_path': self.parent_path, } + if not self.use_repetition_ids: + resp['use_repetition_ids'] = False + return resp @classmethod def _from_json_dict_( @@ -371,10 +385,11 @@ def _from_json_dict_( param_resolver, repetition_ids, parent_path=(), + use_repetition_ids=True, **kwargs, ): return ( - cls(circuit) + cls(circuit, use_repetition_ids=use_repetition_ids) .with_qubit_mapping(dict(qubit_map)) .with_measurement_key_mapping(measurement_key_map) .with_params(param_resolver) diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index 7fbf782ef9b..087d02f4316 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -20,6 +20,13 @@ from cirq.circuits.circuit_operation import _full_join_string_lists +ALL_SIMULATORS = ( + cirq.Simulator(), + cirq.DensityMatrixSimulator(), + cirq.CliffordSimulator(), +) + + def test_properties(): a, b, c = cirq.LineQubit.range(3) circuit = cirq.FrozenCircuit( @@ -457,6 +464,26 @@ def test_string_format(): ]), )""" ) + op6 = cirq.CircuitOperation(fc5, use_repetition_ids=False) + assert ( + repr(op6) + == """\ +cirq.CircuitOperation( + circuit=cirq.FrozenCircuit([ + cirq.Moment( + cirq.X(cirq.LineQubit(0)), + cirq.CircuitOperation( + circuit=cirq.FrozenCircuit([ + cirq.Moment( + cirq.X(cirq.LineQubit(1)), + ), + ]), + ), + ), + ]), + use_repetition_ids=False, +)""" + ) def test_json_dict(): @@ -858,4 +885,47 @@ def test_mapped_circuit_allows_repeated_keys(): ) +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_simulate_no_repetition_ids_both_levels(sim): + q = cirq.LineQubit(0) + inner = cirq.Circuit(cirq.measure(q, key='a')) + middle = cirq.Circuit( + cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False) + ) + outer_subcircuit = cirq.CircuitOperation( + middle.freeze(), repetitions=2, use_repetition_ids=False + ) + circuit = cirq.Circuit(outer_subcircuit) + result = sim.run(circuit) + assert result.records['a'].shape == (1, 4, 1) + + +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_simulate_no_repetition_ids_outer(sim): + q = cirq.LineQubit(0) + inner = cirq.Circuit(cirq.measure(q, key='a')) + middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2)) + outer_subcircuit = cirq.CircuitOperation( + middle.freeze(), repetitions=2, use_repetition_ids=False + ) + circuit = cirq.Circuit(outer_subcircuit) + result = sim.run(circuit) + assert result.records['0:a'].shape == (1, 2, 1) + assert result.records['1:a'].shape == (1, 2, 1) + + +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_simulate_no_repetition_ids_inner(sim): + q = cirq.LineQubit(0) + inner = cirq.Circuit(cirq.measure(q, key='a')) + middle = cirq.Circuit( + cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False) + ) + outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) + circuit = cirq.Circuit(outer_subcircuit) + result = sim.run(circuit) + assert result.records['0:a'].shape == (1, 2, 1) + assert result.records['1:a'].shape == (1, 2, 1) + + # TODO: Operation has a "gate" property. What is this for a CircuitOperation? diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 0daf9f327e7..fc09d939eee 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -482,6 +482,123 @@ def test_scope_local(): assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) +def test_scope_flatten_both(): + q = cirq.LineQubit(0) + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls('a'), + ) + middle = cirq.Circuit( + cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False) + ) + outer_subcircuit = cirq.CircuitOperation( + middle.freeze(), repetitions=2, use_repetition_ids=False + ) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_control_keys = [ + str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) + ] + assert internal_control_keys == ['a', 'a', 'a', 'a'] + assert not cirq.control_keys(outer_subcircuit) + assert not cirq.control_keys(circuit) + cirq.testing.assert_has_diagram( + cirq.Circuit(outer_subcircuit), + """ + [ [ 0: ───M───X─── ] ] +0: ───[ 0: ───[ ║ ║ ]──────────────────────── ]──────────────────────── + [ [ a: ═══@═══^═══ ](loops=2, no_rep_ids) ](loops=2, no_rep_ids) +""", + use_unicode_characters=True, + ) + cirq.testing.assert_has_diagram( + circuit, + """ +0: ───M───X───M───X───M───X───M───X─── + ║ ║ ║ ║ ║ ║ ║ ║ +a: ═══@═══^═══@═══^═══@═══^═══@═══^═══ +""", + use_unicode_characters=True, + ) + + +def test_scope_flatten_inner(): + q = cirq.LineQubit(0) + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls('a'), + ) + middle = cirq.Circuit( + cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False) + ) + outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_control_keys = [ + str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) + ] + assert internal_control_keys == ['0:a', '0:a', '1:a', '1:a'] + assert not cirq.control_keys(outer_subcircuit) + assert not cirq.control_keys(circuit) + cirq.testing.assert_has_diagram( + cirq.Circuit(outer_subcircuit), + """ + [ [ 0: ───M───X─── ] ] +0: ───[ 0: ───[ ║ ║ ]──────────────────────── ]──────────── + [ [ a: ═══@═══^═══ ](loops=2, no_rep_ids) ](loops=2) +""", + use_unicode_characters=True, + ) + cirq.testing.assert_has_diagram( + circuit, + """ +0: ─────M───X───M───X───M───X───M───X─── + ║ ║ ║ ║ ║ ║ ║ ║ +0:a: ═══@═══^═══@═══^═══╬═══╬═══╬═══╬═══ + ║ ║ ║ ║ +1:a: ═══════════════════@═══^═══@═══^═══ +""", + use_unicode_characters=True, + ) + + +def test_scope_flatten_outer(): + q = cirq.LineQubit(0) + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls('a'), + ) + middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2)) + outer_subcircuit = cirq.CircuitOperation( + middle.freeze(), repetitions=2, use_repetition_ids=False + ) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_control_keys = [ + str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) + ] + assert internal_control_keys == ['0:a', '1:a', '0:a', '1:a'] + assert not cirq.control_keys(outer_subcircuit) + assert not cirq.control_keys(circuit) + cirq.testing.assert_has_diagram( + cirq.Circuit(outer_subcircuit), + """ + [ [ 0: ───M───X─── ] ] +0: ───[ 0: ───[ ║ ║ ]──────────── ]──────────────────────── + [ [ a: ═══@═══^═══ ](loops=2) ](loops=2, no_rep_ids) +""", + use_unicode_characters=True, + ) + cirq.testing.assert_has_diagram( + circuit, + """ +0: ─────M───X───M───X───M───X───M───X─── + ║ ║ ║ ║ ║ ║ ║ ║ +0:a: ═══@═══^═══╬═══╬═══@═══^═══╬═══╬═══ + ║ ║ ║ ║ +1:a: ═══════════@═══^═══════════@═══^═══ +""", + use_unicode_characters=True, + ) + + def test_scope_extern(): q = cirq.LineQubit(0) inner = cirq.Circuit( diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json index 73b77c264db..fe87b08ad67 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json @@ -292,7 +292,8 @@ ] }, "parent_path": [], - "repetition_ids": null + "repetition_ids": null, + "use_repetition_ids": false } ] ] diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr index 05010e33fa3..791ebda07e8 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr @@ -28,10 +28,12 @@ cirq.CircuitOperation(circuit=cirq.FrozenCircuit([ repetitions=-2, parent_path=('outer', 'inner'), repetition_ids=['a', 'b'], -qubit_map={cirq.LineQubit(0): cirq.LineQubit(1)}), +qubit_map={cirq.LineQubit(0): cirq.LineQubit(1)}, +use_repetition_ids=True), cirq.CircuitOperation(circuit=cirq.FrozenCircuit([ cirq.Moment( (cirq.X**sympy.Symbol('theta')).on(cirq.LineQubit(0)), ), ]), -param_resolver={sympy.Symbol('theta'): 1.5})] \ No newline at end of file +param_resolver={sympy.Symbol('theta'): 1.5}, +use_repetition_ids=False)] \ No newline at end of file