From cfe4339b62b24d8d192d4f26ef5683cd85dea2f9 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Tue, 22 Feb 2022 15:00:35 -0800 Subject: [PATCH] Replace fields with properties in ActOnArgs (#5011) xref #4851 @95-martin-orion --- cirq/sim/act_on_args.py | 44 ++++++++++++++++++++------ cirq/sim/act_on_args_container.py | 40 +++++++++++++++++++---- cirq/sim/act_on_args_container_test.py | 14 ++++++++ cirq/sim/act_on_args_test.py | 15 +++++++++ 4 files changed, 96 insertions(+), 17 deletions(-) diff --git a/cirq/sim/act_on_args.py b/cirq/sim/act_on_args.py index bb656f8a3c7..0ffd7734ff6 100644 --- a/cirq/sim/act_on_args.py +++ b/cirq/sim/act_on_args.py @@ -17,21 +17,23 @@ import inspect from typing import ( Any, + cast, Dict, + Iterator, List, + Mapping, + Optional, + Sequence, TypeVar, TYPE_CHECKING, - Sequence, Tuple, - cast, - Optional, - Iterator, ) import warnings import numpy as np from cirq import ops, protocols, value +from cirq._compat import deprecated from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits from cirq.sim.operation_target import OperationTarget @@ -74,7 +76,7 @@ def __init__( if qubits is None: qubits = () self._set_qubits(qubits) - self.prng = prng + self._prng = prng self._classical_data = classical_data or value.ClassicalDataDictionaryStore( _records={ value.MeasurementKey.parse_serialized(k): [tuple(v)] @@ -83,9 +85,33 @@ def __init__( ) self._ignore_measurement_results = ignore_measurement_results + @property + def prng(self) -> np.random.RandomState: + return self._prng + + @property + def qubit_map(self) -> Mapping['cirq.Qid', int]: + return self._qubit_map + + @prng.setter # type: ignore + @deprecated( + deadline="v0.15", + fix="The mutators of this class are deprecated, instantiate a new object instead.", + ) + def prng(self, prng): + self._prng = prng + + @qubit_map.setter # type: ignore + @deprecated( + deadline="v0.15", + fix="The mutators of this class are deprecated, instantiate a new object instead.", + ) + def qubit_map(self, qubit_map): + self._qubit_map = qubit_map + def _set_qubits(self, qubits: Sequence['cirq.Qid']): self._qubits = tuple(qubits) - self.qubit_map = {q: i for i, q in enumerate(self.qubits)} + self._qubit_map = {q: i for i, q in enumerate(self.qubits)} def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[bool]): """Measures the qubits and records to `log_of_measurement_results`. @@ -281,8 +307,7 @@ def swap(self, q1: 'cirq.Qid', q2: 'cirq.Qid', *, inplace=False): i2 = self.qubits.index(q2) qubits = list(args.qubits) qubits[i1], qubits[i2] = qubits[i2], qubits[i1] - args._qubits = tuple(qubits) - args.qubit_map = {q: i for i, q in enumerate(qubits)} + args._set_qubits(qubits) return args def rename(self, q1: 'cirq.Qid', q2: 'cirq.Qid', *, inplace=False): @@ -309,8 +334,7 @@ def rename(self, q1: 'cirq.Qid', q2: 'cirq.Qid', *, inplace=False): i1 = self.qubits.index(q1) qubits = list(args.qubits) qubits[i1] = q2 - args._qubits = tuple(qubits) - args.qubit_map = {q: i for i, q in enumerate(qubits)} + args._set_qubits(qubits) return args def __getitem__(self: TSelf, item: Optional['cirq.Qid']) -> TSelf: diff --git a/cirq/sim/act_on_args_container.py b/cirq/sim/act_on_args_container.py index f450ad9bbec..1ee6d24acb0 100644 --- a/cirq/sim/act_on_args_container.py +++ b/cirq/sim/act_on_args_container.py @@ -20,6 +20,7 @@ Generic, Iterator, List, + Mapping, Optional, Sequence, Tuple, @@ -30,6 +31,7 @@ import numpy as np from cirq import ops, protocols, value +from cirq._compat import deprecated from cirq.sim.operation_target import OperationTarget from cirq.sim.simulator import ( TActOnArgs, @@ -68,9 +70,9 @@ def __init__( classical_data: The shared classical data container for this simulation. """ - self.args = args + self._args = args self._qubits = tuple(qubits) - self.split_untangled_states = split_untangled_states + self._split_untangled_states = split_untangled_states self._classical_data = classical_data or value.ClassicalDataDictionaryStore( _records={ value.MeasurementKey.parse_serialized(k): [tuple(v)] @@ -78,6 +80,30 @@ def __init__( } ) + @property + def args(self) -> Mapping[Optional['cirq.Qid'], TActOnArgs]: + return self._args + + @property + def split_untangled_states(self) -> bool: + return self._split_untangled_states + + @args.setter # type: ignore + @deprecated( + deadline="v0.15", + fix="The mutators of this class are deprecated, instantiate a new object instead.", + ) + def args(self, args): + self._args = args + + @split_untangled_states.setter # type: ignore + @deprecated( + deadline="v0.15", + fix="The mutators of this class are deprecated, instantiate a new object instead.", + ) + def split_untangled_states(self, split_untangled_states): + self._split_untangled_states = split_untangled_states + def create_merged_state(self) -> TActOnArgs: if not self.split_untangled_states: return self.args[None] @@ -104,8 +130,8 @@ def _act_on_fallback_( if args0 is args1: args0.swap(q0, q1, inplace=True) else: - self.args[q0] = args1.rename(q1, q0, inplace=True) - self.args[q1] = args0.rename(q0, q1, inplace=True) + self._args[q0] = args1.rename(q1, q0, inplace=True) + self._args[q1] = args0.rename(q0, q1, inplace=True) return True # Go through the op's qubits and join any disparate ActOnArgs states @@ -120,7 +146,7 @@ def _act_on_fallback_( # (Backfill the args map with the new value) for q in op_args.qubits: - self.args[q] = op_args + self._args[q] = op_args # Act on the args with the operation act_on_qubits = qubits if isinstance(action, ops.Gate) else None @@ -134,11 +160,11 @@ def _act_on_fallback_( for q in qubits: if op_args.allows_factoring: q_args, op_args = op_args.factor((q,), validate=False) - self.args[q] = q_args + self._args[q] = q_args # (Backfill the args map with the new value) for q in op_args.qubits: - self.args[q] = op_args + self._args[q] = op_args return True def copy(self, deep_copy_buffers: bool = True) -> 'cirq.ActOnArgsContainer[TActOnArgs]': diff --git a/cirq/sim/act_on_args_container_test.py b/cirq/sim/act_on_args_container_test.py index 597849f4921..384af341ee3 100644 --- a/cirq/sim/act_on_args_container_test.py +++ b/cirq/sim/act_on_args_container_test.py @@ -282,3 +282,17 @@ def test_act_on_gate_does_not_join(): assert len(set(args.values())) == 3 assert args[q0] is not args[q1] assert args[q0] is not args[None] + + +def test_field_getters(): + args = create_container(qs2) + assert args.args.keys() == set(qs2) | {None} + assert args.split_untangled_states + + +def test_field_setters_deprecated(): + args = create_container(qs2) + with cirq.testing.assert_deprecated(deadline='v0.15'): + args.args = {} + with cirq.testing.assert_deprecated(deadline='v0.15'): + args.split_untangled_states = False diff --git a/cirq/sim/act_on_args_test.py b/cirq/sim/act_on_args_test.py index 8b378f398a7..8852777ef6f 100644 --- a/cirq/sim/act_on_args_test.py +++ b/cirq/sim/act_on_args_test.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Sequence, Union +import numpy as np import pytest import cirq @@ -98,3 +99,17 @@ def test_on_copy_has_no_param(): args = DummyArgs() with cirq.testing.assert_deprecated('deep_copy_buffers', deadline='0.15'): args.copy(False) + + +def test_field_getters(): + args = DummyArgs() + assert args.prng is np.random + assert args.qubit_map == {q: i for i, q in enumerate(cirq.LineQubit.range(2))} + + +def test_field_setters_deprecated(): + args = DummyArgs() + with cirq.testing.assert_deprecated(deadline='v0.15'): + args.prng = 0 + with cirq.testing.assert_deprecated(deadline='v0.15'): + args.qubit_map = {}