Skip to content

Commit

Permalink
Replace fields with properties in ActOnArgs (quantumlib#5011)
Browse files Browse the repository at this point in the history
  • Loading branch information
daxfohl authored Feb 22, 2022
1 parent 5bdaf01 commit cfe4339
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 17 deletions.
44 changes: 34 additions & 10 deletions cirq/sim/act_on_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,23 @@
import inspect
from typing import (
Any,
cast,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
TypeVar,
TYPE_CHECKING,
Sequence,
Tuple,
cast,
Optional,
Iterator,
)
import warnings

import numpy as np

from cirq import ops, protocols, value
from cirq._compat import deprecated
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
from cirq.sim.operation_target import OperationTarget

Expand Down Expand Up @@ -74,7 +76,7 @@ def __init__(
if qubits is None:
qubits = ()
self._set_qubits(qubits)
self.prng = prng
self._prng = prng
self._classical_data = classical_data or value.ClassicalDataDictionaryStore(
_records={
value.MeasurementKey.parse_serialized(k): [tuple(v)]
Expand All @@ -83,9 +85,33 @@ def __init__(
)
self._ignore_measurement_results = ignore_measurement_results

@property
def prng(self) -> np.random.RandomState:
return self._prng

@property
def qubit_map(self) -> Mapping['cirq.Qid', int]:
return self._qubit_map

@prng.setter # type: ignore
@deprecated(
deadline="v0.15",
fix="The mutators of this class are deprecated, instantiate a new object instead.",
)
def prng(self, prng):
self._prng = prng

@qubit_map.setter # type: ignore
@deprecated(
deadline="v0.15",
fix="The mutators of this class are deprecated, instantiate a new object instead.",
)
def qubit_map(self, qubit_map):
self._qubit_map = qubit_map

def _set_qubits(self, qubits: Sequence['cirq.Qid']):
self._qubits = tuple(qubits)
self.qubit_map = {q: i for i, q in enumerate(self.qubits)}
self._qubit_map = {q: i for i, q in enumerate(self.qubits)}

def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[bool]):
"""Measures the qubits and records to `log_of_measurement_results`.
Expand Down Expand Up @@ -281,8 +307,7 @@ def swap(self, q1: 'cirq.Qid', q2: 'cirq.Qid', *, inplace=False):
i2 = self.qubits.index(q2)
qubits = list(args.qubits)
qubits[i1], qubits[i2] = qubits[i2], qubits[i1]
args._qubits = tuple(qubits)
args.qubit_map = {q: i for i, q in enumerate(qubits)}
args._set_qubits(qubits)
return args

def rename(self, q1: 'cirq.Qid', q2: 'cirq.Qid', *, inplace=False):
Expand All @@ -309,8 +334,7 @@ def rename(self, q1: 'cirq.Qid', q2: 'cirq.Qid', *, inplace=False):
i1 = self.qubits.index(q1)
qubits = list(args.qubits)
qubits[i1] = q2
args._qubits = tuple(qubits)
args.qubit_map = {q: i for i, q in enumerate(qubits)}
args._set_qubits(qubits)
return args

def __getitem__(self: TSelf, item: Optional['cirq.Qid']) -> TSelf:
Expand Down
40 changes: 33 additions & 7 deletions cirq/sim/act_on_args_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Generic,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Expand All @@ -30,6 +31,7 @@
import numpy as np

from cirq import ops, protocols, value
from cirq._compat import deprecated
from cirq.sim.operation_target import OperationTarget
from cirq.sim.simulator import (
TActOnArgs,
Expand Down Expand Up @@ -68,16 +70,40 @@ def __init__(
classical_data: The shared classical data container for this
simulation.
"""
self.args = args
self._args = args
self._qubits = tuple(qubits)
self.split_untangled_states = split_untangled_states
self._split_untangled_states = split_untangled_states
self._classical_data = classical_data or value.ClassicalDataDictionaryStore(
_records={
value.MeasurementKey.parse_serialized(k): [tuple(v)]
for k, v in (log_of_measurement_results or {}).items()
}
)

@property
def args(self) -> Mapping[Optional['cirq.Qid'], TActOnArgs]:
return self._args

@property
def split_untangled_states(self) -> bool:
return self._split_untangled_states

@args.setter # type: ignore
@deprecated(
deadline="v0.15",
fix="The mutators of this class are deprecated, instantiate a new object instead.",
)
def args(self, args):
self._args = args

@split_untangled_states.setter # type: ignore
@deprecated(
deadline="v0.15",
fix="The mutators of this class are deprecated, instantiate a new object instead.",
)
def split_untangled_states(self, split_untangled_states):
self._split_untangled_states = split_untangled_states

def create_merged_state(self) -> TActOnArgs:
if not self.split_untangled_states:
return self.args[None]
Expand All @@ -104,8 +130,8 @@ def _act_on_fallback_(
if args0 is args1:
args0.swap(q0, q1, inplace=True)
else:
self.args[q0] = args1.rename(q1, q0, inplace=True)
self.args[q1] = args0.rename(q0, q1, inplace=True)
self._args[q0] = args1.rename(q1, q0, inplace=True)
self._args[q1] = args0.rename(q0, q1, inplace=True)
return True

# Go through the op's qubits and join any disparate ActOnArgs states
Expand All @@ -120,7 +146,7 @@ def _act_on_fallback_(

# (Backfill the args map with the new value)
for q in op_args.qubits:
self.args[q] = op_args
self._args[q] = op_args

# Act on the args with the operation
act_on_qubits = qubits if isinstance(action, ops.Gate) else None
Expand All @@ -134,11 +160,11 @@ def _act_on_fallback_(
for q in qubits:
if op_args.allows_factoring:
q_args, op_args = op_args.factor((q,), validate=False)
self.args[q] = q_args
self._args[q] = q_args

# (Backfill the args map with the new value)
for q in op_args.qubits:
self.args[q] = op_args
self._args[q] = op_args
return True

def copy(self, deep_copy_buffers: bool = True) -> 'cirq.ActOnArgsContainer[TActOnArgs]':
Expand Down
14 changes: 14 additions & 0 deletions cirq/sim/act_on_args_container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,17 @@ def test_act_on_gate_does_not_join():
assert len(set(args.values())) == 3
assert args[q0] is not args[q1]
assert args[q0] is not args[None]


def test_field_getters():
args = create_container(qs2)
assert args.args.keys() == set(qs2) | {None}
assert args.split_untangled_states


def test_field_setters_deprecated():
args = create_container(qs2)
with cirq.testing.assert_deprecated(deadline='v0.15'):
args.args = {}
with cirq.testing.assert_deprecated(deadline='v0.15'):
args.split_untangled_states = False
15 changes: 15 additions & 0 deletions cirq/sim/act_on_args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from typing import Sequence, Union

import numpy as np
import pytest

import cirq
Expand Down Expand Up @@ -98,3 +99,17 @@ def test_on_copy_has_no_param():
args = DummyArgs()
with cirq.testing.assert_deprecated('deep_copy_buffers', deadline='0.15'):
args.copy(False)


def test_field_getters():
args = DummyArgs()
assert args.prng is np.random
assert args.qubit_map == {q: i for i, q in enumerate(cirq.LineQubit.range(2))}


def test_field_setters_deprecated():
args = DummyArgs()
with cirq.testing.assert_deprecated(deadline='v0.15'):
args.prng = 0
with cirq.testing.assert_deprecated(deadline='v0.15'):
args.qubit_map = {}

0 comments on commit cfe4339

Please sign in to comment.